зеркало из https://github.com/microsoft/torchgeo.git
Datasets: consistent error messages for missing data (#1714)
* Datasets: consistent error messages for missing data * Fix issues * Increase test coverage * mypy fixes * isort fixes
This commit is contained in:
Родитель
de56a5933c
Коммит
2c65e1d592
|
@ -178,7 +178,7 @@ BioMassters
|
|||
^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: BioMassters
|
||||
|
||||
|
||||
Cloud Cover Detection
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
@ -464,3 +464,8 @@ Splitting Functions
|
|||
.. autofunction:: random_grid_cell_assignment
|
||||
.. autofunction:: roi_split
|
||||
.. autofunction:: time_series_split
|
||||
|
||||
Errors
|
||||
------
|
||||
|
||||
.. autoclass:: DatasetNotFoundError
|
||||
|
|
|
@ -14,7 +14,7 @@ import torch.nn as nn
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import ADVANCE
|
||||
from torchgeo.datasets import ADVANCE, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -68,7 +68,7 @@ class TestADVANCE:
|
|||
ADVANCE(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
ADVANCE(str(tmp_path))
|
||||
|
||||
def test_mock_missing_module(
|
||||
|
|
|
@ -15,6 +15,7 @@ from rasterio.crs import CRS
|
|||
import torchgeo
|
||||
from torchgeo.datasets import (
|
||||
AbovegroundLiveWoodyBiomassDensity,
|
||||
DatasetNotFoundError,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
)
|
||||
|
@ -53,7 +54,7 @@ class TestAbovegroundLiveWoodyBiomassDensity:
|
|||
assert isinstance(x["mask"], torch.Tensor)
|
||||
|
||||
def test_no_dataset(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
AbovegroundLiveWoodyBiomassDensity(str(tmp_path))
|
||||
|
||||
def test_already_downloaded(
|
||||
|
|
|
@ -11,7 +11,13 @@ import torch
|
|||
import torch.nn as nn
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import AsterGDEM, BoundingBox, IntersectionDataset, UnionDataset
|
||||
from torchgeo.datasets import (
|
||||
AsterGDEM,
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
)
|
||||
|
||||
|
||||
class TestAsterGDEM:
|
||||
|
@ -26,7 +32,7 @@ class TestAsterGDEM:
|
|||
def test_datasetmissing(self, tmp_path: Path) -> None:
|
||||
shutil.rmtree(tmp_path)
|
||||
os.makedirs(tmp_path)
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
AsterGDEM(str(tmp_path))
|
||||
|
||||
def test_getitem(self, dataset: AsterGDEM) -> None:
|
||||
|
|
|
@ -13,7 +13,7 @@ import torch.nn as nn
|
|||
from pytest import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import BeninSmallHolderCashews
|
||||
from torchgeo.datasets import BeninSmallHolderCashews, DatasetNotFoundError
|
||||
|
||||
|
||||
class Collection:
|
||||
|
@ -73,7 +73,7 @@ class TestBeninSmallHolderCashews:
|
|||
BeninSmallHolderCashews(root=dataset.root, download=True, api_key="")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
BeninSmallHolderCashews(str(tmp_path))
|
||||
|
||||
def test_invalid_bands(self) -> None:
|
||||
|
|
|
@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import BigEarthNet
|
||||
from torchgeo.datasets import BigEarthNet, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -134,10 +134,7 @@ class TestBigEarthNet:
|
|||
)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
err = "Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
with pytest.raises(RuntimeError, match=err):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
BigEarthNet(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: BigEarthNet) -> None:
|
||||
|
|
|
@ -10,7 +10,7 @@ import matplotlib.pyplot as plt
|
|||
import pytest
|
||||
from _pytest.fixtures import SubRequest
|
||||
|
||||
from torchgeo.datasets import BioMassters
|
||||
from torchgeo.datasets import BioMassters, DatasetNotFoundError
|
||||
|
||||
|
||||
class TestBioMassters:
|
||||
|
@ -36,8 +36,7 @@ class TestBioMassters:
|
|||
BioMassters(dataset.root, sensors=["S3"])
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
match = "Dataset not found"
|
||||
with pytest.raises(RuntimeError, match=match):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
BioMassters(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: BioMassters) -> None:
|
||||
|
|
|
@ -16,6 +16,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import (
|
||||
BoundingBox,
|
||||
CanadianBuildingFootprints,
|
||||
DatasetNotFoundError,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
)
|
||||
|
@ -75,7 +76,7 @@ class TestCanadianBuildingFootprints:
|
|||
dataset.plot(x, suptitle="Prediction")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
CanadianBuildingFootprints(str(tmp_path))
|
||||
|
||||
def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None:
|
||||
|
|
|
@ -15,7 +15,13 @@ from pytest import MonkeyPatch
|
|||
from rasterio.crs import CRS
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import CDL, BoundingBox, IntersectionDataset, UnionDataset
|
||||
from torchgeo.datasets import (
|
||||
CDL,
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -111,7 +117,7 @@ class TestCDL:
|
|||
plt.close()
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
CDL(str(tmp_path))
|
||||
|
||||
def test_invalid_query(self, dataset: CDL) -> None:
|
||||
|
|
|
@ -18,6 +18,7 @@ from torchgeo.datasets import (
|
|||
BoundingBox,
|
||||
Chesapeake13,
|
||||
ChesapeakeCVPR,
|
||||
DatasetNotFoundError,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
)
|
||||
|
@ -70,7 +71,7 @@ class TestChesapeake13:
|
|||
Chesapeake13(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
Chesapeake13(str(tmp_path), checksum=True)
|
||||
|
||||
def test_plot(self, dataset: Chesapeake13) -> None:
|
||||
|
@ -193,7 +194,7 @@ class TestChesapeakeCVPR:
|
|||
ChesapeakeCVPR(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
ChesapeakeCVPR(str(tmp_path), checksum=True)
|
||||
|
||||
def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None:
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import CloudCoverDetection
|
||||
from torchgeo.datasets import CloudCoverDetection, DatasetNotFoundError
|
||||
|
||||
|
||||
class Collection:
|
||||
|
@ -83,7 +83,7 @@ class TestCloudCoverDetection:
|
|||
CloudCoverDetection(root=dataset.root, split="test", download=True, api_key="")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
CloudCoverDetection(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: CloudCoverDetection) -> None:
|
||||
|
|
|
@ -12,7 +12,12 @@ import torch.nn as nn
|
|||
from pytest import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import CMSGlobalMangroveCanopy, IntersectionDataset, UnionDataset
|
||||
from torchgeo.datasets import (
|
||||
CMSGlobalMangroveCanopy,
|
||||
DatasetNotFoundError,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -45,7 +50,7 @@ class TestCMSGlobalMangroveCanopy:
|
|||
assert isinstance(x["mask"], torch.Tensor)
|
||||
|
||||
def test_no_dataset(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
CMSGlobalMangroveCanopy(str(tmp_path))
|
||||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
|
|
|
@ -14,8 +14,7 @@ from pytest import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import COWCCounting, COWCDetection
|
||||
from torchgeo.datasets.cowc import COWC
|
||||
from torchgeo.datasets import COWC, COWCCounting, COWCDetection, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -78,7 +77,7 @@ class TestCOWCCounting:
|
|||
COWCCounting(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
COWCCounting(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: COWCCounting) -> None:
|
||||
|
@ -142,7 +141,7 @@ class TestCOWCDetection:
|
|||
COWCDetection(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
COWCDetection(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: COWCDetection) -> None:
|
||||
|
|
|
@ -13,7 +13,7 @@ import torch.nn as nn
|
|||
from pytest import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import CV4AKenyaCropType
|
||||
from torchgeo.datasets import CV4AKenyaCropType, DatasetNotFoundError
|
||||
|
||||
|
||||
class Collection:
|
||||
|
@ -84,7 +84,7 @@ class TestCV4AKenyaCropType:
|
|||
CV4AKenyaCropType(root=dataset.root, download=True, api_key="")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
CV4AKenyaCropType(str(tmp_path))
|
||||
|
||||
def test_invalid_tile(self, dataset: CV4AKenyaCropType) -> None:
|
||||
|
|
|
@ -14,7 +14,7 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import TropicalCyclone
|
||||
from torchgeo.datasets import DatasetNotFoundError, TropicalCyclone
|
||||
|
||||
|
||||
class Collection:
|
||||
|
@ -80,7 +80,7 @@ class TestTropicalCyclone:
|
|||
TropicalCyclone(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
TropicalCyclone(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: TropicalCyclone) -> None:
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch.nn as nn
|
|||
from _pytest.fixtures import SubRequest
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import DeepGlobeLandCover
|
||||
from torchgeo.datasets import DatasetNotFoundError, DeepGlobeLandCover
|
||||
|
||||
|
||||
class TestDeepGlobeLandCover:
|
||||
|
@ -55,12 +55,7 @@ class TestDeepGlobeLandCover:
|
|||
DeepGlobeLandCover(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Dataset not found in `root`, either"
|
||||
+ " specify a different `root` directory or manually download"
|
||||
+ " the dataset to this directory.",
|
||||
):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
DeepGlobeLandCover(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: DeepGlobeLandCover) -> None:
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch.nn as nn
|
|||
from _pytest.fixtures import SubRequest
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import DFC2022
|
||||
from torchgeo.datasets import DFC2022, DatasetNotFoundError
|
||||
|
||||
|
||||
class TestDFC2022:
|
||||
|
@ -74,7 +74,7 @@ class TestDFC2022:
|
|||
DFC2022(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
DFC2022(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: DFC2022) -> None:
|
||||
|
|
|
@ -6,7 +6,13 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from torchgeo.datasets import BoundingBox, EDDMapS, IntersectionDataset, UnionDataset
|
||||
from torchgeo.datasets import (
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
EDDMapS,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
)
|
||||
|
||||
|
||||
class TestEDDMapS:
|
||||
|
@ -31,7 +37,7 @@ class TestEDDMapS:
|
|||
assert isinstance(ds, UnionDataset)
|
||||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
EDDMapS(str(tmp_path))
|
||||
|
||||
def test_invalid_query(self, dataset: EDDMapS) -> None:
|
||||
|
|
|
@ -16,6 +16,7 @@ from rasterio.crs import CRS
|
|||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import (
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
EnviroAtlas,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
|
@ -88,7 +89,7 @@ class TestEnviroAtlas:
|
|||
EnviroAtlas(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
EnviroAtlas(str(tmp_path), checksum=True)
|
||||
|
||||
def test_out_of_bounds_query(self, dataset: EnviroAtlas) -> None:
|
||||
|
|
|
@ -13,7 +13,13 @@ from pytest import MonkeyPatch
|
|||
from rasterio.crs import CRS
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import BoundingBox, Esri2020, IntersectionDataset, UnionDataset
|
||||
from torchgeo.datasets import (
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
Esri2020,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -60,7 +66,7 @@ class TestEsri2020:
|
|||
Esri2020(str(tmp_path))
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
Esri2020(str(tmp_path), checksum=True)
|
||||
|
||||
def test_and(self, dataset: Esri2020) -> None:
|
||||
|
|
|
@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import ETCI2021
|
||||
from torchgeo.datasets import ETCI2021, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -77,7 +77,7 @@ class TestETCI2021:
|
|||
ETCI2021(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
ETCI2021(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: ETCI2021) -> None:
|
||||
|
|
|
@ -12,7 +12,13 @@ import torch.nn as nn
|
|||
from pytest import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import EUDEM, BoundingBox, IntersectionDataset, UnionDataset
|
||||
from torchgeo.datasets import (
|
||||
EUDEM,
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
)
|
||||
|
||||
|
||||
class TestEUDEM:
|
||||
|
@ -41,7 +47,7 @@ class TestEUDEM:
|
|||
def test_no_dataset(self, tmp_path: Path) -> None:
|
||||
shutil.rmtree(tmp_path)
|
||||
os.makedirs(tmp_path)
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
EUDEM(str(tmp_path))
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
|
|
|
@ -15,7 +15,7 @@ from pytest import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import EuroSAT, EuroSAT100
|
||||
from torchgeo.datasets import DatasetNotFoundError, EuroSAT, EuroSAT100
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -92,10 +92,7 @@ class TestEuroSAT:
|
|||
EuroSAT(root=str(tmp_path), download=False)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
err = "Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
with pytest.raises(RuntimeError, match=err):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
EuroSAT(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: EuroSAT) -> None:
|
||||
|
|
|
@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import FAIR1M
|
||||
from torchgeo.datasets import FAIR1M, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -120,7 +120,7 @@ class TestFAIR1M:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path, dataset: FAIR1M) -> None:
|
||||
shutil.rmtree(str(tmp_path))
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
FAIR1M(root=str(tmp_path), split=dataset.split)
|
||||
|
||||
def test_plot(self, dataset: FAIR1M) -> None:
|
||||
|
|
|
@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import FireRisk
|
||||
from torchgeo.datasets import DatasetNotFoundError, FireRisk
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -56,7 +56,7 @@ class TestFireRisk:
|
|||
FireRisk(root=str(tmp_path), download=False)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
FireRisk(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: FireRisk) -> None:
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch.nn as nn
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import ForestDamage
|
||||
from torchgeo.datasets import DatasetNotFoundError, ForestDamage
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -66,7 +66,7 @@ class TestForestDamage:
|
|||
ForestDamage(root=str(tmp_path), checksum=True)
|
||||
|
||||
def test_not_found(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
ForestDamage(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: ForestDamage) -> None:
|
||||
|
|
|
@ -6,7 +6,13 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from torchgeo.datasets import GBIF, BoundingBox, IntersectionDataset, UnionDataset
|
||||
from torchgeo.datasets import (
|
||||
GBIF,
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
)
|
||||
|
||||
|
||||
class TestGBIF:
|
||||
|
@ -31,7 +37,7 @@ class TestGBIF:
|
|||
assert isinstance(ds, UnionDataset)
|
||||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
GBIF(str(tmp_path))
|
||||
|
||||
def test_invalid_query(self, dataset: GBIF) -> None:
|
||||
|
|
|
@ -16,6 +16,7 @@ from torch.utils.data import ConcatDataset
|
|||
from torchgeo.datasets import (
|
||||
NAIP,
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
GeoDataset,
|
||||
IntersectionDataset,
|
||||
NonGeoClassificationDataset,
|
||||
|
@ -262,7 +263,7 @@ class TestRasterDataset:
|
|||
sentinel[query]
|
||||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError, match="No RasterDataset data was found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
RasterDataset(str(tmp_path))
|
||||
|
||||
def test_no_all_bands(self) -> None:
|
||||
|
@ -327,7 +328,7 @@ class TestVectorDataset:
|
|||
dataset[query]
|
||||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError, match="No VectorDataset data was found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
VectorDataset(str(tmp_path))
|
||||
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import GID15
|
||||
from torchgeo.datasets import GID15, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -58,7 +58,7 @@ class TestGID15:
|
|||
GID15(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
GID15(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: GID15) -> None:
|
||||
|
|
|
@ -14,6 +14,7 @@ from rasterio.crs import CRS
|
|||
|
||||
from torchgeo.datasets import (
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
GlobBiomass,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
|
@ -50,7 +51,7 @@ class TestGlobBiomass:
|
|||
GlobBiomass(dataset.paths)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
GlobBiomass(str(tmp_path), checksum=True)
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
|
|
|
@ -16,7 +16,7 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import IDTReeS
|
||||
from torchgeo.datasets import DatasetNotFoundError, IDTReeS
|
||||
|
||||
pytest.importorskip("laspy", minversion="2")
|
||||
|
||||
|
@ -91,10 +91,7 @@ class TestIDTReeS:
|
|||
IDTReeS(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
err = "Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
with pytest.raises(RuntimeError, match=err):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
IDTReeS(str(tmp_path))
|
||||
|
||||
def test_not_extracted(self, tmp_path: Path) -> None:
|
||||
|
|
|
@ -8,6 +8,7 @@ import pytest
|
|||
|
||||
from torchgeo.datasets import (
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
INaturalist,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
|
@ -36,7 +37,7 @@ class TestINaturalist:
|
|||
assert isinstance(ds, UnionDataset)
|
||||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
INaturalist(str(tmp_path))
|
||||
|
||||
def test_invalid_query(self, dataset: INaturalist) -> None:
|
||||
|
|
|
@ -11,7 +11,7 @@ import torch.nn as nn
|
|||
from _pytest.fixtures import SubRequest
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import InriaAerialImageLabeling
|
||||
from torchgeo.datasets import DatasetNotFoundError, InriaAerialImageLabeling
|
||||
|
||||
|
||||
class TestInriaAerialImageLabeling:
|
||||
|
@ -49,7 +49,7 @@ class TestInriaAerialImageLabeling:
|
|||
InriaAerialImageLabeling(root=dataset.root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: str) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
InriaAerialImageLabeling(str(tmp_path))
|
||||
|
||||
def test_dataset_checksum(self, dataset: InriaAerialImageLabeling) -> None:
|
||||
|
|
|
@ -14,7 +14,13 @@ from pytest import MonkeyPatch
|
|||
from rasterio.crs import CRS
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import BoundingBox, IntersectionDataset, L7Irish, UnionDataset
|
||||
from torchgeo.datasets import (
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
IntersectionDataset,
|
||||
L7Irish,
|
||||
UnionDataset,
|
||||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -68,7 +74,7 @@ class TestL7Irish:
|
|||
L7Irish(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
L7Irish(str(tmp_path))
|
||||
|
||||
def test_plot_prediction(self, dataset: L7Irish) -> None:
|
||||
|
|
|
@ -14,7 +14,13 @@ from pytest import MonkeyPatch
|
|||
from rasterio.crs import CRS
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import BoundingBox, IntersectionDataset, L8Biome, UnionDataset
|
||||
from torchgeo.datasets import (
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
IntersectionDataset,
|
||||
L8Biome,
|
||||
UnionDataset,
|
||||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -68,7 +74,7 @@ class TestL8Biome:
|
|||
L8Biome(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
L8Biome(str(tmp_path))
|
||||
|
||||
def test_plot_prediction(self, dataset: L8Biome) -> None:
|
||||
|
|
|
@ -14,7 +14,12 @@ from pytest import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import BoundingBox, LandCoverAI, LandCoverAIGeo
|
||||
from torchgeo.datasets import (
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
LandCoverAI,
|
||||
LandCoverAIGeo,
|
||||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -49,7 +54,7 @@ class TestLandCoverAIGeo:
|
|||
LandCoverAIGeo(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
LandCoverAIGeo(str(tmp_path))
|
||||
|
||||
def test_out_of_bounds_query(self, dataset: LandCoverAIGeo) -> None:
|
||||
|
@ -115,7 +120,7 @@ class TestLandCoverAI:
|
|||
LandCoverAI(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
LandCoverAI(str(tmp_path))
|
||||
|
||||
def test_invalid_split(self) -> None:
|
||||
|
|
|
@ -12,7 +12,13 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import BoundingBox, IntersectionDataset, Landsat8, UnionDataset
|
||||
from torchgeo.datasets import (
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
IntersectionDataset,
|
||||
Landsat8,
|
||||
UnionDataset,
|
||||
)
|
||||
|
||||
|
||||
class TestLandsat8:
|
||||
|
@ -60,7 +66,7 @@ class TestLandsat8:
|
|||
ds.plot(x)
|
||||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError, match="No Landsat8 data was found in "):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
Landsat8(str(tmp_path))
|
||||
|
||||
def test_invalid_query(self, dataset: Landsat8) -> None:
|
||||
|
|
|
@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import LEVIRCDPlus
|
||||
from torchgeo.datasets import DatasetNotFoundError, LEVIRCDPlus
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -55,7 +55,7 @@ class TestLEVIRCDPlus:
|
|||
LEVIRCDPlus(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
LEVIRCDPlus(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: LEVIRCDPlus) -> None:
|
||||
|
|
|
@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import LoveDA
|
||||
from torchgeo.datasets import DatasetNotFoundError, LoveDA
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -83,9 +83,7 @@ class TestLoveDA:
|
|||
LoveDA(scene=["garden"])
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Dataset not found at root directory or corrupted."
|
||||
):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
LoveDA(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: LoveDA) -> None:
|
||||
|
|
|
@ -15,7 +15,7 @@ from pytest import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import MapInWild
|
||||
from torchgeo.datasets import DatasetNotFoundError, MapInWild
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -97,7 +97,7 @@ class TestMapInWild:
|
|||
MapInWild(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
MapInWild(root=str(tmp_path))
|
||||
|
||||
def test_downloaded_not_extracted(self, tmp_path: Path) -> None:
|
||||
|
|
|
@ -11,7 +11,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
|
||||
from torchgeo.datasets import MillionAID
|
||||
from torchgeo.datasets import DatasetNotFoundError, MillionAID
|
||||
|
||||
|
||||
class TestMillionAID:
|
||||
|
@ -38,7 +38,7 @@ class TestMillionAID:
|
|||
assert len(dataset) == 2
|
||||
|
||||
def test_not_found(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
MillionAID(str(tmp_path))
|
||||
|
||||
def test_not_extracted(self, tmp_path: Path) -> None:
|
||||
|
|
|
@ -10,7 +10,13 @@ import torch
|
|||
import torch.nn as nn
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import NAIP, BoundingBox, IntersectionDataset, UnionDataset
|
||||
from torchgeo.datasets import (
|
||||
NAIP,
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
)
|
||||
|
||||
|
||||
class TestNAIP:
|
||||
|
@ -41,7 +47,7 @@ class TestNAIP:
|
|||
plt.close()
|
||||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError, match="No NAIP data was found in "):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
NAIP(str(tmp_path))
|
||||
|
||||
def test_invalid_query(self, dataset: NAIP) -> None:
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import NASAMarineDebris
|
||||
from torchgeo.datasets import DatasetNotFoundError, NASAMarineDebris
|
||||
|
||||
|
||||
class Collection:
|
||||
|
@ -90,10 +90,7 @@ class TestNASAMarineDebris:
|
|||
NASAMarineDebris(root=str(tmp_path), download=True, checksum=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
err = "Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
with pytest.raises(RuntimeError, match=err):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
NASAMarineDebris(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: NASAMarineDebris) -> None:
|
||||
|
|
|
@ -13,7 +13,13 @@ from pytest import MonkeyPatch
|
|||
from rasterio.crs import CRS
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import NLCD, BoundingBox, IntersectionDataset, UnionDataset
|
||||
from torchgeo.datasets import (
|
||||
NLCD,
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
IntersectionDataset,
|
||||
UnionDataset,
|
||||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -107,7 +113,7 @@ class TestNLCD:
|
|||
plt.close()
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
NLCD(str(tmp_path))
|
||||
|
||||
def test_invalid_query(self, dataset: NLCD) -> None:
|
||||
|
|
|
@ -16,6 +16,7 @@ from rasterio.crs import CRS
|
|||
|
||||
from torchgeo.datasets import (
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
IntersectionDataset,
|
||||
OpenBuildings,
|
||||
UnionDataset,
|
||||
|
@ -52,16 +53,9 @@ class TestOpenBuildings:
|
|||
assert isinstance(x["crs"], CRS)
|
||||
assert isinstance(x["mask"], torch.Tensor)
|
||||
|
||||
def test_no_building_data_found(self, tmp_path: Path) -> None:
|
||||
false_root = os.path.join(tmp_path, "empty")
|
||||
os.makedirs(false_root)
|
||||
shutil.copy(
|
||||
os.path.join("tests", "data", "openbuildings", "tiles.geojson"), false_root
|
||||
)
|
||||
with pytest.raises(
|
||||
RuntimeError, match="have manually downloaded the dataset as suggested "
|
||||
):
|
||||
OpenBuildings(false_root)
|
||||
def test_not_download(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
OpenBuildings(str(tmp_path))
|
||||
|
||||
def test_corrupted(self, dataset: OpenBuildings, tmp_path: Path) -> None:
|
||||
with open(os.path.join(tmp_path, "000_buildings.csv.gz"), "w") as f:
|
||||
|
@ -69,12 +63,6 @@ class TestOpenBuildings:
|
|||
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
|
||||
OpenBuildings(dataset.paths, checksum=True)
|
||||
|
||||
def test_no_meta_data_found(self, tmp_path: Path) -> None:
|
||||
false_root = os.path.join(tmp_path, "empty")
|
||||
os.makedirs(false_root)
|
||||
with pytest.raises(FileNotFoundError, match="Meta data file"):
|
||||
OpenBuildings(false_root)
|
||||
|
||||
def test_nothing_in_index(self, dataset: OpenBuildings, tmp_path: Path) -> None:
|
||||
# change meta data to another 'title_url' so that there is no match found
|
||||
with open(os.path.join(tmp_path, "tiles.geojson")) as f:
|
||||
|
@ -84,7 +72,7 @@ class TestOpenBuildings:
|
|||
with open(os.path.join(tmp_path, "tiles.geojson"), "w") as f:
|
||||
json.dump(content, f)
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="data was found in"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
OpenBuildings(dataset.paths)
|
||||
|
||||
def test_getitem(self, dataset: OpenBuildings) -> None:
|
||||
|
|
|
@ -15,7 +15,7 @@ from pytest import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import OSCD
|
||||
from torchgeo.datasets import OSCD, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -107,7 +107,7 @@ class TestOSCD:
|
|||
OSCD(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
OSCD(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: OSCD) -> None:
|
||||
|
|
|
@ -14,7 +14,7 @@ from pytest import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import PASTIS
|
||||
from torchgeo.datasets import PASTIS, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -80,7 +80,7 @@ class TestPASTIS:
|
|||
PASTIS(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
PASTIS(str(tmp_path))
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch.nn as nn
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import PatternNet
|
||||
from torchgeo.datasets import DatasetNotFoundError, PatternNet
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -52,10 +52,7 @@ class TestPatternNet:
|
|||
PatternNet(root=str(tmp_path), download=False)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
err = "Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
with pytest.raises(RuntimeError, match=err):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
PatternNet(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: PatternNet) -> None:
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch.nn as nn
|
|||
from _pytest.fixtures import SubRequest
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import Potsdam2D
|
||||
from torchgeo.datasets import DatasetNotFoundError, Potsdam2D
|
||||
|
||||
|
||||
class TestPotsdam2D:
|
||||
|
@ -60,7 +60,7 @@ class TestPotsdam2D:
|
|||
Potsdam2D(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
Potsdam2D(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: Potsdam2D) -> None:
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch.nn as nn
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import ReforesTree
|
||||
from torchgeo.datasets import DatasetNotFoundError, ReforesTree
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
|
@ -66,7 +66,7 @@ class TestReforesTree:
|
|||
ReforesTree(root=str(tmp_path), checksum=True)
|
||||
|
||||
def test_not_found(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
ReforesTree(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: ReforesTree) -> None:
|
||||
|
|
|
@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import RESISC45
|
||||
from torchgeo.datasets import RESISC45, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -78,10 +78,7 @@ class TestRESISC45:
|
|||
RESISC45(root=str(tmp_path), download=False)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
err = "Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
with pytest.raises(RuntimeError, match=err):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
RESISC45(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: RESISC45) -> None:
|
||||
|
|
|
@ -14,7 +14,7 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import RwandaFieldBoundary
|
||||
from torchgeo.datasets import DatasetNotFoundError, RwandaFieldBoundary
|
||||
|
||||
|
||||
class Collection:
|
||||
|
@ -87,7 +87,7 @@ class TestRwandaFieldBoundary:
|
|||
RwandaFieldBoundary(root=dataset.root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
RwandaFieldBoundary(str(tmp_path))
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
|
|
|
@ -15,7 +15,7 @@ from pytest import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import SeasoNet
|
||||
from torchgeo.datasets import DatasetNotFoundError, SeasoNet
|
||||
|
||||
|
||||
def download_url(url: str, root: str, md5: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -147,7 +147,7 @@ class TestSeasoNet:
|
|||
SeasoNet(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="not found in"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
SeasoNet(str(tmp_path), download=False)
|
||||
|
||||
def test_out_of_bounds(self, dataset: SeasoNet) -> None:
|
||||
|
|
|
@ -15,7 +15,7 @@ from pytest import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import SeasonalContrastS2
|
||||
from torchgeo.datasets import DatasetNotFoundError, SeasonalContrastS2
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -98,7 +98,7 @@ class TestSeasonalContrastS2:
|
|||
SeasonalContrastS2(bands=["A1steaksauce"])
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
SeasonalContrastS2(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: SeasonalContrastS2) -> None:
|
||||
|
|
|
@ -12,7 +12,7 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
from torch.utils.data import ConcatDataset
|
||||
|
||||
from torchgeo.datasets import SEN12MS
|
||||
from torchgeo.datasets import SEN12MS, DatasetNotFoundError
|
||||
|
||||
|
||||
class TestSEN12MS:
|
||||
|
@ -65,10 +65,10 @@ class TestSEN12MS:
|
|||
SEN12MS(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
SEN12MS(str(tmp_path), checksum=True)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
SEN12MS(str(tmp_path), checksum=False)
|
||||
|
||||
def test_check_integrity_light(self) -> None:
|
||||
|
|
|
@ -13,6 +13,7 @@ from rasterio.crs import CRS
|
|||
|
||||
from torchgeo.datasets import (
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
IntersectionDataset,
|
||||
Sentinel1,
|
||||
Sentinel2,
|
||||
|
@ -64,7 +65,7 @@ class TestSentinel1:
|
|||
plt.close()
|
||||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError, match="No Sentinel1 data was found in "):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
Sentinel1(str(tmp_path))
|
||||
|
||||
def test_empty_bands(self) -> None:
|
||||
|
@ -123,7 +124,7 @@ class TestSentinel2:
|
|||
assert isinstance(ds, UnionDataset)
|
||||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError, match="No Sentinel2 data was found in "):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
Sentinel2(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: Sentinel2) -> None:
|
||||
|
|
|
@ -16,7 +16,7 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import SKIPPD
|
||||
from torchgeo.datasets import SKIPPD, DatasetNotFoundError
|
||||
|
||||
pytest.importorskip("h5py", minversion="3")
|
||||
|
||||
|
@ -105,7 +105,7 @@ class TestSKIPPD:
|
|||
SKIPPD(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
SKIPPD(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: SKIPPD) -> None:
|
||||
|
|
|
@ -13,7 +13,7 @@ import torch.nn as nn
|
|||
from _pytest.fixtures import SubRequest
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import So2Sat
|
||||
from torchgeo.datasets import DatasetNotFoundError, So2Sat
|
||||
|
||||
pytest.importorskip("h5py", minversion="3")
|
||||
|
||||
|
@ -70,7 +70,7 @@ class TestSo2Sat:
|
|||
So2Sat(bands=("OK", "BK"))
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
So2Sat(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: So2Sat) -> None:
|
||||
|
|
|
@ -14,6 +14,7 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import (
|
||||
DatasetNotFoundError,
|
||||
SpaceNet1,
|
||||
SpaceNet2,
|
||||
SpaceNet3,
|
||||
|
@ -91,7 +92,7 @@ class TestSpaceNet1:
|
|||
SpaceNet1(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
SpaceNet1(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: SpaceNet1) -> None:
|
||||
|
@ -147,7 +148,7 @@ class TestSpaceNet2:
|
|||
SpaceNet2(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
SpaceNet2(str(tmp_path))
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet2) -> None:
|
||||
|
@ -207,7 +208,7 @@ class TestSpaceNet3:
|
|||
SpaceNet3(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
SpaceNet3(str(tmp_path))
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet3) -> None:
|
||||
|
@ -271,7 +272,7 @@ class TestSpaceNet4:
|
|||
SpaceNet4(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
SpaceNet4(str(tmp_path))
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet4) -> None:
|
||||
|
@ -333,7 +334,7 @@ class TestSpaceNet5:
|
|||
SpaceNet5(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
SpaceNet5(str(tmp_path))
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet5) -> None:
|
||||
|
@ -427,7 +428,7 @@ class TestSpaceNet7:
|
|||
SpaceNet7(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
SpaceNet7(str(tmp_path))
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet4) -> None:
|
||||
|
|
|
@ -15,7 +15,7 @@ from pytest import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo
|
||||
from torchgeo.datasets import SSL4EOL, SSL4EOS12
|
||||
from torchgeo.datasets import SSL4EOL, SSL4EOS12, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -94,7 +94,7 @@ class TestSSL4EOL:
|
|||
SSL4EOL(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
SSL4EOL(str(tmp_path))
|
||||
|
||||
def test_invalid_split(self) -> None:
|
||||
|
@ -155,7 +155,7 @@ class TestSSL4EOS12:
|
|||
SSL4EOS12(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
SSL4EOS12(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: SSL4EOS12) -> None:
|
||||
|
|
|
@ -16,7 +16,13 @@ from pytest import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import CDL, NLCD, RasterDataset, SSL4EOLBenchmark
|
||||
from torchgeo.datasets import (
|
||||
CDL,
|
||||
NLCD,
|
||||
DatasetNotFoundError,
|
||||
RasterDataset,
|
||||
SSL4EOLBenchmark,
|
||||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -137,7 +143,7 @@ class TestSSL4EOLBenchmark:
|
|||
SSL4EOLBenchmark(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
SSL4EOLBenchmark(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: SSL4EOLBenchmark) -> None:
|
||||
|
|
|
@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import SustainBenchCropYield
|
||||
from torchgeo.datasets import DatasetNotFoundError, SustainBenchCropYield
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -71,7 +71,7 @@ class TestSustainBenchCropYield:
|
|||
SustainBenchCropYield(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
SustainBenchCropYield(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: SustainBenchCropYield) -> None:
|
||||
|
|
|
@ -14,7 +14,7 @@ from pytest import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import UCMerced
|
||||
from torchgeo.datasets import DatasetNotFoundError, UCMerced
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -81,10 +81,7 @@ class TestUCMerced:
|
|||
UCMerced(root=str(tmp_path), download=False)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
err = "Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
with pytest.raises(RuntimeError, match=err):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
UCMerced(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: UCMerced) -> None:
|
||||
|
|
|
@ -14,7 +14,7 @@ from pytest import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import USAVars
|
||||
from torchgeo.datasets import DatasetNotFoundError, USAVars
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
|
@ -129,7 +129,7 @@ class TestUSAVars:
|
|||
USAVars(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
USAVars(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: USAVars) -> None:
|
||||
|
|
|
@ -18,10 +18,12 @@ import pytest
|
|||
import torch
|
||||
from pytest import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets.utils import (
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
concat_samples,
|
||||
disambiguate_timestamp,
|
||||
download_and_extract_archive,
|
||||
|
@ -36,6 +38,52 @@ from torchgeo.datasets.utils import (
|
|||
)
|
||||
|
||||
|
||||
class TestDatasetNotFoundError:
|
||||
def test_none(self) -> None:
|
||||
ds: Dataset[Any] = Dataset()
|
||||
match = "Dataset not found."
|
||||
with pytest.raises(DatasetNotFoundError, match=match):
|
||||
raise DatasetNotFoundError(ds)
|
||||
|
||||
def test_root(self) -> None:
|
||||
ds: Dataset[Any] = Dataset()
|
||||
ds.root = "foo" # type: ignore[attr-defined]
|
||||
match = "Dataset not found in `root='foo'` and cannot be automatically "
|
||||
match += "downloaded, either specify a different `root` or manually "
|
||||
match += "download the dataset."
|
||||
with pytest.raises(DatasetNotFoundError, match=match):
|
||||
raise DatasetNotFoundError(ds)
|
||||
|
||||
def test_paths(self) -> None:
|
||||
ds: Dataset[Any] = Dataset()
|
||||
ds.paths = "foo" # type: ignore[attr-defined]
|
||||
match = "Dataset not found in `paths='foo'` and cannot be automatically "
|
||||
match += "downloaded, either specify a different `paths` or manually "
|
||||
match += "download the dataset."
|
||||
with pytest.raises(DatasetNotFoundError, match=match):
|
||||
raise DatasetNotFoundError(ds)
|
||||
|
||||
def test_root_download(self) -> None:
|
||||
ds: Dataset[Any] = Dataset()
|
||||
ds.root = "foo" # type: ignore[attr-defined]
|
||||
ds.download = False # type: ignore[attr-defined]
|
||||
match = "Dataset not found in `root='foo'` and `download=False`, either "
|
||||
match += "specify a different `root` or use `download=True` to automatically "
|
||||
match += "download the dataset."
|
||||
with pytest.raises(DatasetNotFoundError, match=match):
|
||||
raise DatasetNotFoundError(ds)
|
||||
|
||||
def test_paths_download(self) -> None:
|
||||
ds: Dataset[Any] = Dataset()
|
||||
ds.paths = "foo" # type: ignore[attr-defined]
|
||||
ds.download = False # type: ignore[attr-defined]
|
||||
match = "Dataset not found in `paths='foo'` and `download=False`, either "
|
||||
match += "specify a different `paths` or use `download=True` to automatically "
|
||||
match += "download the dataset."
|
||||
with pytest.raises(DatasetNotFoundError, match=match):
|
||||
raise DatasetNotFoundError(ds)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_missing_module(monkeypatch: MonkeyPatch) -> None:
|
||||
import_orig = builtins.__import__
|
||||
|
@ -48,7 +96,7 @@ def mock_missing_module(monkeypatch: MonkeyPatch) -> None:
|
|||
monkeypatch.setattr(builtins, "__import__", mocked_import)
|
||||
|
||||
|
||||
class Dataset:
|
||||
class MLHubDataset:
|
||||
def download(self, output_dir: str, **kwargs: str) -> None:
|
||||
glob_path = os.path.join(
|
||||
"tests", "data", "ref_african_crops_kenya_02", "*.tar.gz"
|
||||
|
@ -66,8 +114,8 @@ class Collection:
|
|||
shutil.copy(tarball, output_dir)
|
||||
|
||||
|
||||
def fetch_dataset(dataset_id: str, **kwargs: str) -> Dataset:
|
||||
return Dataset()
|
||||
def fetch_dataset(dataset_id: str, **kwargs: str) -> MLHubDataset:
|
||||
return MLHubDataset()
|
||||
|
||||
|
||||
def fetch_collection(collection_id: str, **kwargs: str) -> Collection:
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch.nn as nn
|
|||
from _pytest.fixtures import SubRequest
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import Vaihingen2D
|
||||
from torchgeo.datasets import DatasetNotFoundError, Vaihingen2D
|
||||
|
||||
|
||||
class TestVaihingen2D:
|
||||
|
@ -69,7 +69,7 @@ class TestVaihingen2D:
|
|||
Vaihingen2D(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
Vaihingen2D(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: Vaihingen2D) -> None:
|
||||
|
|
|
@ -16,7 +16,7 @@ from pytest import MonkeyPatch
|
|||
from torch.utils.data import ConcatDataset
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import VHR10
|
||||
from torchgeo.datasets import VHR10, DatasetNotFoundError
|
||||
|
||||
pytest.importorskip("pycocotools")
|
||||
|
||||
|
@ -90,7 +90,7 @@ class TestVHR10:
|
|||
VHR10(split="train")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
VHR10(str(tmp_path))
|
||||
|
||||
def test_mock_missing_module(
|
||||
|
|
|
@ -10,7 +10,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import WesternUSALiveFuelMoisture
|
||||
from torchgeo.datasets import DatasetNotFoundError, WesternUSALiveFuelMoisture
|
||||
|
||||
|
||||
class Collection:
|
||||
|
@ -65,7 +65,7 @@ class TestWesternUSALiveFuelMoisture:
|
|||
WesternUSALiveFuelMoisture(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
WesternUSALiveFuelMoisture(str(tmp_path))
|
||||
|
||||
def test_invalid_features(self, dataset: WesternUSALiveFuelMoisture) -> None:
|
||||
|
|
|
@ -12,7 +12,7 @@ import torch.nn as nn
|
|||
from _pytest.fixtures import SubRequest
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import XView2
|
||||
from torchgeo.datasets import DatasetNotFoundError, XView2
|
||||
|
||||
|
||||
class TestXView2:
|
||||
|
@ -80,7 +80,7 @@ class TestXView2:
|
|||
XView2(split="foo")
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
XView2(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: XView2) -> None:
|
||||
|
|
|
@ -14,7 +14,7 @@ import torch.nn as nn
|
|||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import ZueriCrop
|
||||
from torchgeo.datasets import DatasetNotFoundError, ZueriCrop
|
||||
|
||||
pytest.importorskip("h5py", minversion="3")
|
||||
|
||||
|
@ -79,10 +79,7 @@ class TestZueriCrop:
|
|||
ZueriCrop(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
err = "Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
with pytest.raises(RuntimeError, match=err):
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
ZueriCrop(str(tmp_path))
|
||||
|
||||
def test_mock_missing_module(
|
||||
|
|
|
@ -116,6 +116,7 @@ from .ucmerced import UCMerced
|
|||
from .usavars import USAVars
|
||||
from .utils import (
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
concat_samples,
|
||||
merge_samples,
|
||||
stack_samples,
|
||||
|
@ -253,4 +254,6 @@ __all__ = (
|
|||
"random_grid_cell_assignment",
|
||||
"roi_split",
|
||||
"time_series_split",
|
||||
# Errors
|
||||
"DatasetNotFoundError",
|
||||
)
|
||||
|
|
|
@ -15,7 +15,7 @@ from PIL import Image
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import download_and_extract_archive
|
||||
from .utils import DatasetNotFoundError, download_and_extract_archive
|
||||
|
||||
|
||||
class ADVANCE(NonGeoDataset):
|
||||
|
@ -101,8 +101,7 @@ class ADVANCE(NonGeoDataset):
|
|||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` and data is not found, or checksums
|
||||
don't match
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
"""
|
||||
self.root = root
|
||||
self.transforms = transforms
|
||||
|
@ -112,10 +111,7 @@ class ADVANCE(NonGeoDataset):
|
|||
self._download()
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError(
|
||||
"Dataset not found or corrupted. "
|
||||
+ "You can use download=True to download it"
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
self.files = self._load_files(self.root)
|
||||
self.classes = sorted({f["cls"] for f in self.files})
|
||||
|
@ -218,11 +214,7 @@ class ADVANCE(NonGeoDataset):
|
|||
return True
|
||||
|
||||
def _download(self) -> None:
|
||||
"""Download the dataset and extract it.
|
||||
|
||||
Raises:
|
||||
AssertionError: if the checksum of split.py does not match
|
||||
"""
|
||||
"""Download the dataset and extract it."""
|
||||
if self._check_integrity():
|
||||
print("Files already downloaded and verified")
|
||||
return
|
||||
|
|
|
@ -13,7 +13,7 @@ from matplotlib.figure import Figure
|
|||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
from .utils import download_url
|
||||
from .utils import DatasetNotFoundError, download_url
|
||||
|
||||
|
||||
class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
|
||||
|
@ -77,7 +77,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
|
|||
cache: if True, cache file handle to speed up repeated sampling
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: if no files are found in ``paths``
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
|
||||
.. versionchanged:: 0.5
|
||||
*root* was renamed to *paths*.
|
||||
|
@ -90,22 +90,14 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
|
|||
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if dataset is missing
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the extracted files already exist
|
||||
if self.files:
|
||||
return
|
||||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise RuntimeError(
|
||||
f"Dataset not found in `paths={self.paths!r}` and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# Download the dataset
|
||||
self._download()
|
||||
|
|
|
@ -10,6 +10,7 @@ from matplotlib.figure import Figure
|
|||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
from .utils import DatasetNotFoundError
|
||||
|
||||
|
||||
class AsterGDEM(RasterDataset):
|
||||
|
@ -65,8 +66,7 @@ class AsterGDEM(RasterDataset):
|
|||
cache: if True, cache file handle to speed up repeated sampling
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: if no files are found in ``paths``
|
||||
RuntimeError: if dataset is missing
|
||||
DatasetNotFoundError: If dataset is not found.
|
||||
|
||||
.. versionchanged:: 0.5
|
||||
*root* was renamed to *paths*.
|
||||
|
@ -78,20 +78,12 @@ class AsterGDEM(RasterDataset):
|
|||
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if dataset is missing
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the extracted files already exists
|
||||
if self.files:
|
||||
return
|
||||
|
||||
raise RuntimeError(
|
||||
f"Dataset not found in `paths={self.paths!r}` "
|
||||
"either specify a different `root` directory or make sure you "
|
||||
"have manually downloaded dataset tiles as suggested in the documentation."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
|
|
|
@ -18,7 +18,12 @@ from rasterio.crs import CRS
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
|
||||
from .utils import (
|
||||
DatasetNotFoundError,
|
||||
check_integrity,
|
||||
download_radiant_mlhub_collection,
|
||||
extract_archive,
|
||||
)
|
||||
|
||||
|
||||
# TODO: read geospatial information from stac.json files
|
||||
|
@ -198,7 +203,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
|
|||
verbose: if True, print messages when new tiles are loaded
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
"""
|
||||
self._validate_bands(bands)
|
||||
|
||||
|
@ -214,10 +219,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
|
|||
self._download(api_key)
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError(
|
||||
"Dataset not found or corrupted. "
|
||||
+ "You can use download=True to download it"
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# Calculate the indices that we will use over all tiles
|
||||
self.chips_metadata = []
|
||||
|
|
|
@ -17,7 +17,12 @@ from rasterio.enums import Resampling
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import download_url, extract_archive, sort_sentinel2_bands
|
||||
from .utils import (
|
||||
DatasetNotFoundError,
|
||||
download_url,
|
||||
extract_archive,
|
||||
sort_sentinel2_bands,
|
||||
)
|
||||
|
||||
|
||||
class BigEarthNet(NonGeoDataset):
|
||||
|
@ -285,6 +290,9 @@ class BigEarthNet(NonGeoDataset):
|
|||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory
|
||||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
"""
|
||||
assert split in self.splits_metadata
|
||||
assert bands in ["s1", "s2", "all"]
|
||||
|
@ -434,11 +442,7 @@ class BigEarthNet(NonGeoDataset):
|
|||
return target
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
keys = ["s1", "s2"] if self.bands == "all" else [self.bands]
|
||||
urls = [self.metadata[k]["url"] for k in keys]
|
||||
md5s = [self.metadata[k]["md5"] for k in keys]
|
||||
|
@ -478,11 +482,7 @@ class BigEarthNet(NonGeoDataset):
|
|||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise RuntimeError(
|
||||
"Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# Download and extract the dataset
|
||||
for url, filename, md5 in zip(urls, filenames, md5s):
|
||||
|
|
|
@ -16,7 +16,7 @@ from matplotlib.figure import Figure
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import percentile_normalization
|
||||
from .utils import DatasetNotFoundError, percentile_normalization
|
||||
|
||||
|
||||
class BioMassters(NonGeoDataset):
|
||||
|
@ -75,8 +75,9 @@ class BioMassters(NonGeoDataset):
|
|||
as_time_series: whether or not to return all available
|
||||
time-steps or just a single one for a given target location
|
||||
|
||||
RuntimeError:
|
||||
Raises:
|
||||
AssertionError: if ``split`` or ``sensors`` is invalid
|
||||
DatasetNotFoundError: If dataset is not found.
|
||||
"""
|
||||
self.root = root
|
||||
|
||||
|
@ -212,7 +213,7 @@ class BioMassters(NonGeoDataset):
|
|||
if all(exists):
|
||||
return
|
||||
|
||||
raise RuntimeError(f"Dataset not found in `root={self.root}`.")
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
|
|
|
@ -12,7 +12,7 @@ from matplotlib.figure import Figure
|
|||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import VectorDataset
|
||||
from .utils import check_integrity, download_and_extract_archive
|
||||
from .utils import DatasetNotFoundError, check_integrity, download_and_extract_archive
|
||||
|
||||
|
||||
class CanadianBuildingFootprints(VectorDataset):
|
||||
|
@ -81,9 +81,7 @@ class CanadianBuildingFootprints(VectorDataset):
|
|||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: if no files are found in ``root``
|
||||
RuntimeError: if ``download=False`` and data is not found, or
|
||||
``checksum=True`` and checksums don't match
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
|
||||
.. versionchanged:: 0.5
|
||||
*root* was renamed to *paths*.
|
||||
|
@ -95,10 +93,7 @@ class CanadianBuildingFootprints(VectorDataset):
|
|||
self._download()
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError(
|
||||
"Dataset not found or corrupted. "
|
||||
+ "You can use download=True to download it"
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
super().__init__(paths, crs, res, transforms)
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from matplotlib.figure import Figure
|
|||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
from .utils import BoundingBox, download_url, extract_archive
|
||||
from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
|
||||
|
||||
|
||||
class CDL(RasterDataset):
|
||||
|
@ -234,8 +234,7 @@ class CDL(RasterDataset):
|
|||
|
||||
Raises:
|
||||
AssertionError: if ``years`` or ``classes`` are invalid
|
||||
FileNotFoundError: if no files are found in ``paths``
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
|
||||
.. versionadded:: 0.5
|
||||
The *years* and *classes* parameters.
|
||||
|
@ -286,11 +285,7 @@ class CDL(RasterDataset):
|
|||
return sample
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the extracted files already exist
|
||||
if self.files:
|
||||
return
|
||||
|
@ -313,11 +308,7 @@ class CDL(RasterDataset):
|
|||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise RuntimeError(
|
||||
f"Dataset not found in `paths={self.paths!r}` and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# Download the dataset
|
||||
self._download()
|
||||
|
|
|
@ -24,7 +24,7 @@ from rasterio.crs import CRS
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import GeoDataset, RasterDataset
|
||||
from .utils import BoundingBox, download_url, extract_archive
|
||||
from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
|
||||
|
||||
|
||||
class Chesapeake(RasterDataset, abc.ABC):
|
||||
|
@ -112,8 +112,7 @@ class Chesapeake(RasterDataset, abc.ABC):
|
|||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: if no files are found in ``paths``
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
|
||||
.. versionchanged:: 0.5
|
||||
*root* was renamed to *paths*.
|
||||
|
@ -138,11 +137,7 @@ class Chesapeake(RasterDataset, abc.ABC):
|
|||
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the extracted file already exists
|
||||
if self.files:
|
||||
return
|
||||
|
@ -155,11 +150,7 @@ class Chesapeake(RasterDataset, abc.ABC):
|
|||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise RuntimeError(
|
||||
f"Dataset not found in `paths={self.paths!r}` and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# Download the dataset
|
||||
self._download()
|
||||
|
@ -562,9 +553,8 @@ class ChesapeakeCVPR(GeoDataset):
|
|||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: if no files are found in ``root``
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
AssertionError: if ``splits`` or ``layers`` are not valid
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
"""
|
||||
for split in splits:
|
||||
assert split in self.splits
|
||||
|
@ -694,11 +684,7 @@ class ChesapeakeCVPR(GeoDataset):
|
|||
return sample
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
|
||||
def exists(filename: str) -> bool:
|
||||
return os.path.exists(os.path.join(self.root, filename))
|
||||
|
@ -719,11 +705,7 @@ class ChesapeakeCVPR(GeoDataset):
|
|||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise RuntimeError(
|
||||
f"Dataset not found in `root={self.root}` and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# Download the dataset
|
||||
self._download()
|
||||
|
|
|
@ -16,7 +16,12 @@ from matplotlib.figure import Figure
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
|
||||
from .utils import (
|
||||
DatasetNotFoundError,
|
||||
check_integrity,
|
||||
download_radiant_mlhub_collection,
|
||||
extract_archive,
|
||||
)
|
||||
|
||||
|
||||
# TODO: read geospatial information from stac.json files
|
||||
|
@ -123,7 +128,7 @@ class CloudCoverDetection(NonGeoDataset):
|
|||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
"""
|
||||
self.root = root
|
||||
self.split = split
|
||||
|
@ -137,10 +142,7 @@ class CloudCoverDetection(NonGeoDataset):
|
|||
self._download(api_key)
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError(
|
||||
"Dataset not found or corrupted. "
|
||||
+ "You can use download=True to download it"
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
self.chip_paths = self._load_collections()
|
||||
|
||||
|
@ -331,9 +333,6 @@ class CloudCoverDetection(NonGeoDataset):
|
|||
|
||||
Args:
|
||||
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
|
||||
|
||||
Raises:
|
||||
RuntimeError: if download doesn't work correctly or checksums don't match
|
||||
"""
|
||||
if self._check_integrity():
|
||||
print("Files already downloaded and verified")
|
||||
|
|
|
@ -11,7 +11,7 @@ from matplotlib.figure import Figure
|
|||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
from .utils import check_integrity, extract_archive
|
||||
from .utils import DatasetNotFoundError, check_integrity, extract_archive
|
||||
|
||||
|
||||
class CMSGlobalMangroveCanopy(RasterDataset):
|
||||
|
@ -192,9 +192,8 @@ class CMSGlobalMangroveCanopy(RasterDataset):
|
|||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: if no files are found in ``paths``
|
||||
RuntimeError: if dataset is missing or checksum fails
|
||||
AssertionError: if country or measurement arg are not str or invalid
|
||||
DatasetNotFoundError: If dataset is not found.
|
||||
|
||||
.. versionchanged:: 0.5
|
||||
*root* was renamed to *paths*.
|
||||
|
@ -225,11 +224,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
|
|||
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if dataset is missing or checksum fails
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the extracted files already exist
|
||||
if self.files:
|
||||
return
|
||||
|
@ -243,11 +238,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
|
|||
self._extract()
|
||||
return
|
||||
|
||||
raise RuntimeError(
|
||||
f"Dataset not found in `paths={self.paths!r}` "
|
||||
"either specify a different `root` directory or make sure you "
|
||||
"have manually downloaded the dataset as instructed in the documentation."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
def _extract(self) -> None:
|
||||
"""Extract the dataset."""
|
||||
|
|
|
@ -16,7 +16,7 @@ from PIL import Image
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import check_integrity, download_and_extract_archive
|
||||
from .utils import DatasetNotFoundError, check_integrity, download_and_extract_archive
|
||||
|
||||
|
||||
class COWC(NonGeoDataset, abc.ABC):
|
||||
|
@ -81,8 +81,7 @@ class COWC(NonGeoDataset, abc.ABC):
|
|||
|
||||
Raises:
|
||||
AssertionError: if ``split`` argument is invalid
|
||||
RuntimeError: if ``download=False`` and data is not found, or checksums
|
||||
don't match
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
"""
|
||||
assert split in ["train", "test"]
|
||||
|
||||
|
@ -95,10 +94,7 @@ class COWC(NonGeoDataset, abc.ABC):
|
|||
self._download()
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError(
|
||||
"Dataset not found or corrupted. "
|
||||
+ "You can use download=True to download it"
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
self.images = []
|
||||
self.targets = []
|
||||
|
|
|
@ -16,7 +16,12 @@ from PIL import Image
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
|
||||
from .utils import (
|
||||
DatasetNotFoundError,
|
||||
check_integrity,
|
||||
download_radiant_mlhub_collection,
|
||||
extract_archive,
|
||||
)
|
||||
|
||||
|
||||
# TODO: read geospatial information from stac.json files
|
||||
|
@ -141,7 +146,7 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
verbose: if True, print messages when new tiles are loaded
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
"""
|
||||
self._validate_bands(bands)
|
||||
|
||||
|
@ -157,10 +162,7 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
self._download(api_key)
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError(
|
||||
"Dataset not found or corrupted. "
|
||||
+ "You can use download=True to download it"
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# Calculate the indices that we will use over all tiles
|
||||
self.chips_metadata = []
|
||||
|
@ -390,9 +392,6 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
|
||||
Args:
|
||||
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
|
||||
|
||||
Raises:
|
||||
RuntimeError: if download doesn't work correctly or checksums don't match
|
||||
"""
|
||||
if self._check_integrity():
|
||||
print("Files already downloaded and verified")
|
||||
|
|
|
@ -16,7 +16,12 @@ from PIL import Image
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
|
||||
from .utils import (
|
||||
DatasetNotFoundError,
|
||||
check_integrity,
|
||||
download_radiant_mlhub_collection,
|
||||
extract_archive,
|
||||
)
|
||||
|
||||
|
||||
class TropicalCyclone(NonGeoDataset):
|
||||
|
@ -86,7 +91,7 @@ class TropicalCyclone(NonGeoDataset):
|
|||
|
||||
Raises:
|
||||
AssertionError: if ``split`` argument is invalid
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
"""
|
||||
assert split in self.md5s
|
||||
|
||||
|
@ -99,10 +104,7 @@ class TropicalCyclone(NonGeoDataset):
|
|||
self._download(api_key)
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError(
|
||||
"Dataset not found or corrupted. "
|
||||
+ "You can use download=True to download it"
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
output_dir = "_".join([self.collection_id, split, "source"])
|
||||
filename = os.path.join(root, output_dir, "collection.json")
|
||||
|
@ -206,9 +208,6 @@ class TropicalCyclone(NonGeoDataset):
|
|||
|
||||
Args:
|
||||
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
|
||||
|
||||
Raises:
|
||||
RuntimeError: if download doesn't work correctly or checksums don't match
|
||||
"""
|
||||
if self._check_integrity():
|
||||
print("Files already downloaded and verified")
|
||||
|
|
|
@ -15,6 +15,7 @@ from torch import Tensor
|
|||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import (
|
||||
DatasetNotFoundError,
|
||||
check_integrity,
|
||||
draw_semantic_segmentation_masks,
|
||||
extract_archive,
|
||||
|
@ -102,6 +103,9 @@ class DeepGlobeLandCover(NonGeoDataset):
|
|||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
DatasetNotFoundError: If dataset is not found.
|
||||
"""
|
||||
assert split in self.splits
|
||||
self.root = root
|
||||
|
@ -195,11 +199,7 @@ class DeepGlobeLandCover(NonGeoDataset):
|
|||
return tensor
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if checksum fails or the dataset is not downloaded
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the files already exist
|
||||
if os.path.exists(os.path.join(self.root, self.data_root)):
|
||||
return
|
||||
|
@ -213,11 +213,7 @@ class DeepGlobeLandCover(NonGeoDataset):
|
|||
extract_archive(filepath)
|
||||
return
|
||||
|
||||
# Check if the user requested to download the dataset
|
||||
raise RuntimeError(
|
||||
"Dataset not found in `root`, either specify a different"
|
||||
+ " `root` directory or manually download the dataset to this directory."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
|
|
|
@ -18,7 +18,12 @@ from rasterio.enums import Resampling
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import check_integrity, extract_archive, percentile_normalization
|
||||
from .utils import (
|
||||
DatasetNotFoundError,
|
||||
check_integrity,
|
||||
extract_archive,
|
||||
percentile_normalization,
|
||||
)
|
||||
|
||||
|
||||
class DFC2022(NonGeoDataset):
|
||||
|
@ -153,6 +158,7 @@ class DFC2022(NonGeoDataset):
|
|||
|
||||
Raises:
|
||||
AssertionError: if ``split`` is invalid
|
||||
DatasetNotFoundError: If dataset is not found.
|
||||
"""
|
||||
assert split in self.metadata
|
||||
self.root = root
|
||||
|
@ -258,11 +264,7 @@ class DFC2022(NonGeoDataset):
|
|||
return tensor
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if checksum fails or the dataset is not downloaded
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the files already exist
|
||||
exists = []
|
||||
for split_info in self.metadata.values():
|
||||
|
@ -288,11 +290,7 @@ class DFC2022(NonGeoDataset):
|
|||
if all(exists):
|
||||
return
|
||||
|
||||
# Check if the user requested to download the dataset
|
||||
raise RuntimeError(
|
||||
"Dataset not found in `root` directory, either specify a different"
|
||||
+ " `root` directory or manually download the dataset to this directory."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
|
|
|
@ -12,7 +12,7 @@ import pandas as pd
|
|||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import GeoDataset
|
||||
from .utils import BoundingBox, disambiguate_timestamp
|
||||
from .utils import BoundingBox, DatasetNotFoundError, disambiguate_timestamp
|
||||
|
||||
|
||||
class EDDMapS(GeoDataset):
|
||||
|
@ -48,7 +48,7 @@ class EDDMapS(GeoDataset):
|
|||
root: root directory where dataset can be found
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: if no files are found in ``root``
|
||||
DatasetNotFoundError: If dataset is not found.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
@ -56,7 +56,7 @@ class EDDMapS(GeoDataset):
|
|||
|
||||
filepath = os.path.join(root, "mappings.csv")
|
||||
if not os.path.exists(filepath):
|
||||
raise FileNotFoundError(f"Dataset not found in `root={self.root}`")
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# Read CSV file
|
||||
data = pd.read_csv(
|
||||
|
|
|
@ -22,7 +22,7 @@ from matplotlib.figure import Figure
|
|||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import GeoDataset
|
||||
from .utils import BoundingBox, download_url, extract_archive
|
||||
from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
|
||||
|
||||
|
||||
class EnviroAtlas(GeoDataset):
|
||||
|
@ -278,9 +278,8 @@ class EnviroAtlas(GeoDataset):
|
|||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: if no files are found in ``root``
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
AssertionError: if ``splits`` or ``layers`` are not valid
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
"""
|
||||
for split in splits:
|
||||
assert split in self.splits
|
||||
|
@ -412,11 +411,7 @@ class EnviroAtlas(GeoDataset):
|
|||
return sample
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
|
||||
def exists(filename: str) -> bool:
|
||||
return os.path.exists(os.path.join(self.root, "enviroatlas_lotp", filename))
|
||||
|
@ -432,11 +427,7 @@ class EnviroAtlas(GeoDataset):
|
|||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise RuntimeError(
|
||||
f"Dataset not found in `root={self.root}` and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# Download the dataset
|
||||
self._download()
|
||||
|
|
|
@ -13,7 +13,7 @@ from matplotlib.figure import Figure
|
|||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
from .utils import download_url, extract_archive
|
||||
from .utils import DatasetNotFoundError, download_url, extract_archive
|
||||
|
||||
|
||||
class Esri2020(RasterDataset):
|
||||
|
@ -91,8 +91,7 @@ class Esri2020(RasterDataset):
|
|||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: if no files are found in ``paths``
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
|
||||
.. versionchanged:: 0.5
|
||||
*root* was renamed to *paths*.
|
||||
|
@ -106,11 +105,7 @@ class Esri2020(RasterDataset):
|
|||
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the extracted file already exists
|
||||
if self.files:
|
||||
return
|
||||
|
@ -124,11 +119,7 @@ class Esri2020(RasterDataset):
|
|||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise RuntimeError(
|
||||
f"Dataset not found in `paths={self.paths!r}` and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# Download the dataset
|
||||
self._download()
|
||||
|
|
|
@ -15,7 +15,7 @@ from PIL import Image
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import download_and_extract_archive
|
||||
from .utils import DatasetNotFoundError, download_and_extract_archive
|
||||
|
||||
|
||||
class ETCI2021(NonGeoDataset):
|
||||
|
@ -98,8 +98,7 @@ class ETCI2021(NonGeoDataset):
|
|||
|
||||
Raises:
|
||||
AssertionError: if ``split`` argument is invalid
|
||||
RuntimeError: if ``download=False`` and data is not found, or checksums
|
||||
don't match
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
"""
|
||||
assert split in self.metadata.keys()
|
||||
|
||||
|
@ -112,10 +111,7 @@ class ETCI2021(NonGeoDataset):
|
|||
self._download()
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError(
|
||||
"Dataset not found or corrupted. "
|
||||
+ "You can use download=True to download it"
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
self.files = self._load_files(self.root, self.split)
|
||||
|
||||
|
@ -243,11 +239,7 @@ class ETCI2021(NonGeoDataset):
|
|||
return True
|
||||
|
||||
def _download(self) -> None:
|
||||
"""Download the dataset and extract it.
|
||||
|
||||
Raises:
|
||||
AssertionError: if the checksum of split.py does not match
|
||||
"""
|
||||
"""Download the dataset and extract it."""
|
||||
if self._check_integrity():
|
||||
print("Files already downloaded and verified")
|
||||
return
|
||||
|
|
|
@ -13,7 +13,7 @@ from matplotlib.figure import Figure
|
|||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
from .utils import check_integrity, extract_archive
|
||||
from .utils import DatasetNotFoundError, check_integrity, extract_archive
|
||||
|
||||
|
||||
class EUDEM(RasterDataset):
|
||||
|
@ -105,7 +105,7 @@ class EUDEM(RasterDataset):
|
|||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: if no files are found in ``paths``
|
||||
DatasetNotFoundError: If dataset is not found.
|
||||
|
||||
.. versionchanged:: 0.5
|
||||
*root* was renamed to *paths*.
|
||||
|
@ -118,11 +118,7 @@ class EUDEM(RasterDataset):
|
|||
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if dataset is missing or checksum fails
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the extracted file already exists
|
||||
if self.files:
|
||||
return
|
||||
|
@ -138,11 +134,7 @@ class EUDEM(RasterDataset):
|
|||
extract_archive(zipfile)
|
||||
return
|
||||
|
||||
raise RuntimeError(
|
||||
f"Dataset not found in `paths={self.paths!r}` "
|
||||
"either specify a different `root` directory or make sure you "
|
||||
"have manually downloaded the dataset as suggested in the documentation."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
|
|
|
@ -14,7 +14,13 @@ from matplotlib.figure import Figure
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoClassificationDataset
|
||||
from .utils import check_integrity, download_url, extract_archive, rasterio_loader
|
||||
from .utils import (
|
||||
DatasetNotFoundError,
|
||||
check_integrity,
|
||||
download_url,
|
||||
extract_archive,
|
||||
rasterio_loader,
|
||||
)
|
||||
|
||||
|
||||
class EuroSAT(NonGeoClassificationDataset):
|
||||
|
@ -116,8 +122,7 @@ class EuroSAT(NonGeoClassificationDataset):
|
|||
|
||||
Raises:
|
||||
AssertionError: if ``split`` argument is invalid
|
||||
RuntimeError: if ``download=False`` and data is not found, or checksums
|
||||
don't match
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
|
||||
.. versionadded:: 0.3
|
||||
The *bands* parameter.
|
||||
|
@ -180,11 +185,7 @@ class EuroSAT(NonGeoClassificationDataset):
|
|||
return integrity
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the files already exist
|
||||
filepath = os.path.join(self.root, self.base_dir)
|
||||
if os.path.exists(filepath):
|
||||
|
@ -197,11 +198,7 @@ class EuroSAT(NonGeoClassificationDataset):
|
|||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise RuntimeError(
|
||||
"Dataset not found in `root` directory and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# Download and extract the dataset
|
||||
self._download()
|
||||
|
|
|
@ -17,7 +17,7 @@ from PIL import Image
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import check_integrity, download_url, extract_archive
|
||||
from .utils import DatasetNotFoundError, check_integrity, download_url, extract_archive
|
||||
|
||||
|
||||
def parse_pascal_voc(path: str) -> dict[str, Any]:
|
||||
|
@ -244,8 +244,7 @@ class FAIR1M(NonGeoDataset):
|
|||
|
||||
Raises:
|
||||
AssertionError: if ``split`` argument is invalid
|
||||
RuntimeError: if ``download=False`` and data is not found, or checksums
|
||||
don't match
|
||||
DatasetNotFoundError: If dataset is not found.
|
||||
|
||||
.. versionchanged:: 0.5
|
||||
Added *split* and *download* parameters.
|
||||
|
@ -329,11 +328,7 @@ class FAIR1M(NonGeoDataset):
|
|||
return boxes, labels_tensor
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if checksum fails or the dataset is not found
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the directories already exist
|
||||
exists = []
|
||||
for directory in self.directories[self.split]:
|
||||
|
@ -362,18 +357,10 @@ class FAIR1M(NonGeoDataset):
|
|||
self._download()
|
||||
return
|
||||
|
||||
raise RuntimeError(
|
||||
f"Dataset not found in `root={self.root}` and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
def _download(self) -> None:
|
||||
"""Download the dataset and extract it.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if download doesn't work correctly or checksums don't match
|
||||
"""
|
||||
"""Download the dataset and extract it."""
|
||||
paths = self.paths[self.split]
|
||||
urls = self.urls[self.split]
|
||||
md5s = self.md5s[self.split]
|
||||
|
|
|
@ -11,7 +11,7 @@ from matplotlib.figure import Figure
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoClassificationDataset
|
||||
from .utils import download_url, extract_archive
|
||||
from .utils import DatasetNotFoundError, download_url, extract_archive
|
||||
|
||||
|
||||
class FireRisk(NonGeoClassificationDataset):
|
||||
|
@ -84,7 +84,7 @@ class FireRisk(NonGeoClassificationDataset):
|
|||
|
||||
Raises:
|
||||
AssertionError: if ``split`` argument is invalid
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
"""
|
||||
assert split in self.splits
|
||||
self.root = root
|
||||
|
@ -98,11 +98,7 @@ class FireRisk(NonGeoClassificationDataset):
|
|||
)
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the files already exist
|
||||
path = os.path.join(self.root, self.directory)
|
||||
if os.path.exists(path):
|
||||
|
@ -116,11 +112,7 @@ class FireRisk(NonGeoClassificationDataset):
|
|||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise RuntimeError(
|
||||
f"Dataset not found in `root={self.root}` and `download=False`, "
|
||||
"either specify a different `root` directory or use `download=True` "
|
||||
"to automatically download the dataset."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# Download and extract the dataset
|
||||
self._download()
|
||||
|
|
|
@ -17,7 +17,12 @@ from PIL import Image
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import check_integrity, download_and_extract_archive, extract_archive
|
||||
from .utils import (
|
||||
DatasetNotFoundError,
|
||||
check_integrity,
|
||||
download_and_extract_archive,
|
||||
extract_archive,
|
||||
)
|
||||
|
||||
|
||||
def parse_pascal_voc(path: str) -> dict[str, Any]:
|
||||
|
@ -119,8 +124,7 @@ class ForestDamage(NonGeoDataset):
|
|||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
RuntimeError: if ``download=False`` and data is not found, or checksums
|
||||
don't match
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
"""
|
||||
self.root = root
|
||||
self.transforms = transforms
|
||||
|
@ -237,21 +241,13 @@ class ForestDamage(NonGeoDataset):
|
|||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise RuntimeError(
|
||||
"Dataset not found in `root` directory, either specify a different"
|
||||
+ " `root` directory or manually download "
|
||||
+ "the dataset to this directory."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# else download the dataset
|
||||
self._download()
|
||||
|
||||
def _download(self) -> None:
|
||||
"""Download the dataset and extract it.
|
||||
|
||||
Raises:
|
||||
AssertionError: if the checksum does not match
|
||||
"""
|
||||
"""Download the dataset and extract it."""
|
||||
download_and_extract_archive(
|
||||
self.url,
|
||||
self.root,
|
||||
|
|
|
@ -14,7 +14,7 @@ import pandas as pd
|
|||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import GeoDataset
|
||||
from .utils import BoundingBox
|
||||
from .utils import BoundingBox, DatasetNotFoundError
|
||||
|
||||
|
||||
def _disambiguate_timestamps(
|
||||
|
@ -86,7 +86,7 @@ class GBIF(GeoDataset):
|
|||
root: root directory where dataset can be found
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: if no files are found in ``root``
|
||||
DatasetNotFoundError: If dataset is not found.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
@ -94,7 +94,7 @@ class GBIF(GeoDataset):
|
|||
|
||||
files = glob.glob(os.path.join(root, "**.csv"))
|
||||
if not files:
|
||||
raise FileNotFoundError(f"Dataset not found in `root={self.root}`")
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# Read tab-delimited CSV file
|
||||
data = pd.read_table(
|
||||
|
|
|
@ -32,6 +32,7 @@ from torchvision.datasets.folder import default_loader as pil_loader
|
|||
|
||||
from .utils import (
|
||||
BoundingBox,
|
||||
DatasetNotFoundError,
|
||||
concat_samples,
|
||||
disambiguate_timestamp,
|
||||
merge_samples,
|
||||
|
@ -390,7 +391,7 @@ class RasterDataset(GeoDataset):
|
|||
cache: if True, cache file handle to speed up repeated sampling
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: if no files are found in ``paths``
|
||||
DatasetNotFoundError: If dataset is not found.
|
||||
|
||||
.. versionchanged:: 0.5
|
||||
*root* was renamed to *paths*.
|
||||
|
@ -438,13 +439,7 @@ class RasterDataset(GeoDataset):
|
|||
i += 1
|
||||
|
||||
if i == 0:
|
||||
msg = (
|
||||
f"No {self.__class__.__name__} data was found "
|
||||
f"in `paths={self.paths!r}'`"
|
||||
)
|
||||
if self.bands:
|
||||
msg += f" with `bands={self.bands}`"
|
||||
raise FileNotFoundError(msg)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
if not self.separate_files:
|
||||
self.band_indexes = None
|
||||
|
@ -606,7 +601,7 @@ class VectorDataset(GeoDataset):
|
|||
rasterized into the mask
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: if no files are found in ``root``
|
||||
DatasetNotFoundError: If dataset is not found.
|
||||
|
||||
.. versionadded:: 0.4
|
||||
The *label_name* parameter.
|
||||
|
@ -642,8 +637,7 @@ class VectorDataset(GeoDataset):
|
|||
i += 1
|
||||
|
||||
if i == 0:
|
||||
msg = f"No {self.__class__.__name__} data was found in `root='{paths}'`"
|
||||
raise FileNotFoundError(msg)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
self._crs = crs
|
||||
self._res = res
|
||||
|
|
|
@ -15,7 +15,7 @@ from PIL import Image
|
|||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import download_and_extract_archive
|
||||
from .utils import DatasetNotFoundError, download_and_extract_archive
|
||||
|
||||
|
||||
class GID15(NonGeoDataset):
|
||||
|
@ -105,8 +105,7 @@ class GID15(NonGeoDataset):
|
|||
|
||||
Raises:
|
||||
AssertionError: if ``split`` argument is invalid
|
||||
RuntimeError: if ``download=False`` and data is not found, or checksums
|
||||
don't match
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
"""
|
||||
assert split in self.splits
|
||||
|
||||
|
@ -119,10 +118,7 @@ class GID15(NonGeoDataset):
|
|||
self._download()
|
||||
|
||||
if not self._check_integrity():
|
||||
raise RuntimeError(
|
||||
"Dataset not found or corrupted. "
|
||||
+ "You can use download=True to download it"
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
self.files = self._load_files(self.root, self.split)
|
||||
|
||||
|
@ -226,11 +222,7 @@ class GID15(NonGeoDataset):
|
|||
return True
|
||||
|
||||
def _download(self) -> None:
|
||||
"""Download the dataset and extract it.
|
||||
|
||||
Raises:
|
||||
AssertionError: if the checksum of split.py does not match
|
||||
"""
|
||||
"""Download the dataset and extract it."""
|
||||
if self._check_integrity():
|
||||
print("Files already downloaded and verified")
|
||||
return
|
||||
|
|
|
@ -14,7 +14,7 @@ from matplotlib.figure import Figure
|
|||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
from .utils import BoundingBox, check_integrity, extract_archive
|
||||
from .utils import BoundingBox, DatasetNotFoundError, check_integrity, extract_archive
|
||||
|
||||
|
||||
class GlobBiomass(RasterDataset):
|
||||
|
@ -142,9 +142,8 @@ class GlobBiomass(RasterDataset):
|
|||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: if no files are found in ``paths``
|
||||
RuntimeError: if dataset is missing or checksum fails
|
||||
AssertionError: if measurement argument is invalid, or not a str
|
||||
DatasetNotFoundError: If dataset is not found.
|
||||
|
||||
.. versionchanged:: 0.5
|
||||
*root* was renamed to *paths*.
|
||||
|
@ -204,11 +203,7 @@ class GlobBiomass(RasterDataset):
|
|||
return sample
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if dataset is missing or checksum fails
|
||||
"""
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if the extracted file already exists
|
||||
if self.files:
|
||||
return
|
||||
|
@ -224,11 +219,7 @@ class GlobBiomass(RasterDataset):
|
|||
extract_archive(zipfile)
|
||||
return
|
||||
|
||||
raise RuntimeError(
|
||||
f"Dataset not found in `paths={self.paths!r}` "
|
||||
"either specify a different `root` directory or make sure you "
|
||||
"have manually downloaded the dataset as suggested in the documentation."
|
||||
)
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
|
|
Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше
Загрузка…
Ссылка в новой задаче