Datasets: consistent error messages for missing data (#1714)

* Datasets: consistent error messages for missing data

* Fix issues

* Increase test coverage

* mypy fixes

* isort fixes
This commit is contained in:
Adam J. Stewart 2023-11-07 07:53:10 -06:00 committed by GitHub
Parent de56a5933c
Commit 2c65e1d592
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
139 changed files: 727 additions and 963 deletions
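The change itself is mechanical: every missing-data test stops matching a dataset-specific `RuntimeError`/`FileNotFoundError` message and instead expects the shared `DatasetNotFoundError`. A minimal sketch of the resulting test pattern, using `EuroSAT` purely as an example dataset taken from this diff:

```python
import pytest

from torchgeo.datasets import DatasetNotFoundError, EuroSAT


def test_not_downloaded(tmp_path) -> None:
    # Before this commit, each dataset matched its own message, e.g.
    #   pytest.raises(RuntimeError, match="Dataset not found or corrupted.")
    # Now every dataset raises the same exception type with a consistent
    # "Dataset not found" prefix.
    with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
        EuroSAT(str(tmp_path))
```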


@ -178,7 +178,7 @@ BioMassters
^^^^^^^^^^^
.. autoclass:: BioMassters
Cloud Cover Detection
^^^^^^^^^^^^^^^^^^^^^
@ -464,3 +464,8 @@ Splitting Functions
.. autofunction:: random_grid_cell_assignment
.. autofunction:: roi_split
.. autofunction:: time_series_split
Errors
------
.. autoclass:: DatasetNotFoundError


@ -14,7 +14,7 @@ import torch.nn as nn
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ADVANCE
from torchgeo.datasets import ADVANCE, DatasetNotFoundError
def download_url(url: str, root: str, *args: str) -> None:
@ -68,7 +68,7 @@ class TestADVANCE:
ADVANCE(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
ADVANCE(str(tmp_path))
def test_mock_missing_module(


@ -15,6 +15,7 @@ from rasterio.crs import CRS
import torchgeo
from torchgeo.datasets import (
AbovegroundLiveWoodyBiomassDensity,
DatasetNotFoundError,
IntersectionDataset,
UnionDataset,
)
@ -53,7 +54,7 @@ class TestAbovegroundLiveWoodyBiomassDensity:
assert isinstance(x["mask"], torch.Tensor)
def test_no_dataset(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
AbovegroundLiveWoodyBiomassDensity(str(tmp_path))
def test_already_downloaded(


@ -11,7 +11,13 @@ import torch
import torch.nn as nn
from rasterio.crs import CRS
from torchgeo.datasets import AsterGDEM, BoundingBox, IntersectionDataset, UnionDataset
from torchgeo.datasets import (
AsterGDEM,
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
UnionDataset,
)
class TestAsterGDEM:
@ -26,7 +32,7 @@ class TestAsterGDEM:
def test_datasetmissing(self, tmp_path: Path) -> None:
shutil.rmtree(tmp_path)
os.makedirs(tmp_path)
with pytest.raises(RuntimeError, match="Dataset not found in"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
AsterGDEM(str(tmp_path))
def test_getitem(self, dataset: AsterGDEM) -> None:


@ -13,7 +13,7 @@ import torch.nn as nn
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import BeninSmallHolderCashews
from torchgeo.datasets import BeninSmallHolderCashews, DatasetNotFoundError
class Collection:
@ -73,7 +73,7 @@ class TestBeninSmallHolderCashews:
BeninSmallHolderCashews(root=dataset.root, download=True, api_key="")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
BeninSmallHolderCashews(str(tmp_path))
def test_invalid_bands(self) -> None:


@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import BigEarthNet
from torchgeo.datasets import BigEarthNet, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -134,10 +134,7 @@ class TestBigEarthNet:
)
def test_not_downloaded(self, tmp_path: Path) -> None:
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
with pytest.raises(RuntimeError, match=err):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
BigEarthNet(str(tmp_path))
def test_plot(self, dataset: BigEarthNet) -> None:


@ -10,7 +10,7 @@ import matplotlib.pyplot as plt
import pytest
from _pytest.fixtures import SubRequest
from torchgeo.datasets import BioMassters
from torchgeo.datasets import BioMassters, DatasetNotFoundError
class TestBioMassters:
@ -36,8 +36,7 @@ class TestBioMassters:
BioMassters(dataset.root, sensors=["S3"])
def test_not_downloaded(self, tmp_path: Path) -> None:
match = "Dataset not found"
with pytest.raises(RuntimeError, match=match):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
BioMassters(str(tmp_path))
def test_plot(self, dataset: BioMassters) -> None:


@ -16,6 +16,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import (
BoundingBox,
CanadianBuildingFootprints,
DatasetNotFoundError,
IntersectionDataset,
UnionDataset,
)
@ -75,7 +76,7 @@ class TestCanadianBuildingFootprints:
dataset.plot(x, suptitle="Prediction")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
CanadianBuildingFootprints(str(tmp_path))
def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None:


@ -15,7 +15,13 @@ from pytest import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import CDL, BoundingBox, IntersectionDataset, UnionDataset
from torchgeo.datasets import (
CDL,
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
UnionDataset,
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -111,7 +117,7 @@ class TestCDL:
plt.close()
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
CDL(str(tmp_path))
def test_invalid_query(self, dataset: CDL) -> None:


@ -18,6 +18,7 @@ from torchgeo.datasets import (
BoundingBox,
Chesapeake13,
ChesapeakeCVPR,
DatasetNotFoundError,
IntersectionDataset,
UnionDataset,
)
@ -70,7 +71,7 @@ class TestChesapeake13:
Chesapeake13(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
Chesapeake13(str(tmp_path), checksum=True)
def test_plot(self, dataset: Chesapeake13) -> None:
@ -193,7 +194,7 @@ class TestChesapeakeCVPR:
ChesapeakeCVPR(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
ChesapeakeCVPR(str(tmp_path), checksum=True)
def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None:


@ -12,7 +12,7 @@ import torch
import torch.nn as nn
from pytest import MonkeyPatch
from torchgeo.datasets import CloudCoverDetection
from torchgeo.datasets import CloudCoverDetection, DatasetNotFoundError
class Collection:
@ -83,7 +83,7 @@ class TestCloudCoverDetection:
CloudCoverDetection(root=dataset.root, split="test", download=True, api_key="")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
CloudCoverDetection(str(tmp_path))
def test_plot(self, dataset: CloudCoverDetection) -> None:


@ -12,7 +12,12 @@ import torch.nn as nn
from pytest import MonkeyPatch
from rasterio.crs import CRS
from torchgeo.datasets import CMSGlobalMangroveCanopy, IntersectionDataset, UnionDataset
from torchgeo.datasets import (
CMSGlobalMangroveCanopy,
DatasetNotFoundError,
IntersectionDataset,
UnionDataset,
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -45,7 +50,7 @@ class TestCMSGlobalMangroveCanopy:
assert isinstance(x["mask"], torch.Tensor)
def test_no_dataset(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
CMSGlobalMangroveCanopy(str(tmp_path))
def test_already_downloaded(self, tmp_path: Path) -> None:


@ -14,8 +14,7 @@ from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import COWCCounting, COWCDetection
from torchgeo.datasets.cowc import COWC
from torchgeo.datasets import COWC, COWCCounting, COWCDetection, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -78,7 +77,7 @@ class TestCOWCCounting:
COWCCounting(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
COWCCounting(str(tmp_path))
def test_plot(self, dataset: COWCCounting) -> None:
@ -142,7 +141,7 @@ class TestCOWCDetection:
COWCDetection(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
COWCDetection(str(tmp_path))
def test_plot(self, dataset: COWCDetection) -> None:


@ -13,7 +13,7 @@ import torch.nn as nn
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import CV4AKenyaCropType
from torchgeo.datasets import CV4AKenyaCropType, DatasetNotFoundError
class Collection:
@ -84,7 +84,7 @@ class TestCV4AKenyaCropType:
CV4AKenyaCropType(root=dataset.root, download=True, api_key="")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
CV4AKenyaCropType(str(tmp_path))
def test_invalid_tile(self, dataset: CV4AKenyaCropType) -> None:


@ -14,7 +14,7 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import TropicalCyclone
from torchgeo.datasets import DatasetNotFoundError, TropicalCyclone
class Collection:
@ -80,7 +80,7 @@ class TestTropicalCyclone:
TropicalCyclone(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
TropicalCyclone(str(tmp_path))
def test_plot(self, dataset: TropicalCyclone) -> None:


@ -12,7 +12,7 @@ import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchgeo.datasets import DeepGlobeLandCover
from torchgeo.datasets import DatasetNotFoundError, DeepGlobeLandCover
class TestDeepGlobeLandCover:
@ -55,12 +55,7 @@ class TestDeepGlobeLandCover:
DeepGlobeLandCover(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(
RuntimeError,
match="Dataset not found in `root`, either"
+ " specify a different `root` directory or manually download"
+ " the dataset to this directory.",
):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
DeepGlobeLandCover(str(tmp_path))
def test_plot(self, dataset: DeepGlobeLandCover) -> None:


@ -12,7 +12,7 @@ import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchgeo.datasets import DFC2022
from torchgeo.datasets import DFC2022, DatasetNotFoundError
class TestDFC2022:
@ -74,7 +74,7 @@ class TestDFC2022:
DFC2022(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
DFC2022(str(tmp_path))
def test_plot(self, dataset: DFC2022) -> None:


@ -6,7 +6,13 @@ from pathlib import Path
import pytest
from torchgeo.datasets import BoundingBox, EDDMapS, IntersectionDataset, UnionDataset
from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
EDDMapS,
IntersectionDataset,
UnionDataset,
)
class TestEDDMapS:
@ -31,7 +37,7 @@ class TestEDDMapS:
assert isinstance(ds, UnionDataset)
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
EDDMapS(str(tmp_path))
def test_invalid_query(self, dataset: EDDMapS) -> None:


@ -16,6 +16,7 @@ from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
EnviroAtlas,
IntersectionDataset,
UnionDataset,
@ -88,7 +89,7 @@ class TestEnviroAtlas:
EnviroAtlas(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
EnviroAtlas(str(tmp_path), checksum=True)
def test_out_of_bounds_query(self, dataset: EnviroAtlas) -> None:


@ -13,7 +13,13 @@ from pytest import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import BoundingBox, Esri2020, IntersectionDataset, UnionDataset
from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
Esri2020,
IntersectionDataset,
UnionDataset,
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -60,7 +66,7 @@ class TestEsri2020:
Esri2020(str(tmp_path))
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
Esri2020(str(tmp_path), checksum=True)
def test_and(self, dataset: Esri2020) -> None:


@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ETCI2021
from torchgeo.datasets import ETCI2021, DatasetNotFoundError
def download_url(url: str, root: str, *args: str) -> None:
@ -77,7 +77,7 @@ class TestETCI2021:
ETCI2021(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
ETCI2021(str(tmp_path))
def test_plot(self, dataset: ETCI2021) -> None:


@ -12,7 +12,13 @@ import torch.nn as nn
from pytest import MonkeyPatch
from rasterio.crs import CRS
from torchgeo.datasets import EUDEM, BoundingBox, IntersectionDataset, UnionDataset
from torchgeo.datasets import (
EUDEM,
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
UnionDataset,
)
class TestEUDEM:
@ -41,7 +47,7 @@ class TestEUDEM:
def test_no_dataset(self, tmp_path: Path) -> None:
shutil.rmtree(tmp_path)
os.makedirs(tmp_path)
with pytest.raises(RuntimeError, match="Dataset not found in"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
EUDEM(str(tmp_path))
def test_corrupted(self, tmp_path: Path) -> None:


@ -15,7 +15,7 @@ from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import EuroSAT, EuroSAT100
from torchgeo.datasets import DatasetNotFoundError, EuroSAT, EuroSAT100
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -92,10 +92,7 @@ class TestEuroSAT:
EuroSAT(root=str(tmp_path), download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
with pytest.raises(RuntimeError, match=err):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
EuroSAT(str(tmp_path))
def test_plot(self, dataset: EuroSAT) -> None:


@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import FAIR1M
from torchgeo.datasets import FAIR1M, DatasetNotFoundError
def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
@ -120,7 +120,7 @@ class TestFAIR1M:
def test_not_downloaded(self, tmp_path: Path, dataset: FAIR1M) -> None:
shutil.rmtree(str(tmp_path))
with pytest.raises(RuntimeError, match="Dataset not found in"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
FAIR1M(root=str(tmp_path), split=dataset.split)
def test_plot(self, dataset: FAIR1M) -> None:


@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import FireRisk
from torchgeo.datasets import DatasetNotFoundError, FireRisk
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -56,7 +56,7 @@ class TestFireRisk:
FireRisk(root=str(tmp_path), download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
FireRisk(str(tmp_path))
def test_plot(self, dataset: FireRisk) -> None:


@ -12,7 +12,7 @@ import torch.nn as nn
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ForestDamage
from torchgeo.datasets import DatasetNotFoundError, ForestDamage
def download_url(url: str, root: str, *args: str) -> None:
@ -66,7 +66,7 @@ class TestForestDamage:
ForestDamage(root=str(tmp_path), checksum=True)
def test_not_found(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
ForestDamage(str(tmp_path))
def test_plot(self, dataset: ForestDamage) -> None:


@ -6,7 +6,13 @@ from pathlib import Path
import pytest
from torchgeo.datasets import GBIF, BoundingBox, IntersectionDataset, UnionDataset
from torchgeo.datasets import (
GBIF,
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
UnionDataset,
)
class TestGBIF:
@ -31,7 +37,7 @@ class TestGBIF:
assert isinstance(ds, UnionDataset)
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
GBIF(str(tmp_path))
def test_invalid_query(self, dataset: GBIF) -> None:


@ -16,6 +16,7 @@ from torch.utils.data import ConcatDataset
from torchgeo.datasets import (
NAIP,
BoundingBox,
DatasetNotFoundError,
GeoDataset,
IntersectionDataset,
NonGeoClassificationDataset,
@ -262,7 +263,7 @@ class TestRasterDataset:
sentinel[query]
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No RasterDataset data was found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
RasterDataset(str(tmp_path))
def test_no_all_bands(self) -> None:
@ -327,7 +328,7 @@ class TestVectorDataset:
dataset[query]
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No VectorDataset data was found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
VectorDataset(str(tmp_path))


@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import GID15
from torchgeo.datasets import GID15, DatasetNotFoundError
def download_url(url: str, root: str, *args: str) -> None:
@ -58,7 +58,7 @@ class TestGID15:
GID15(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
GID15(str(tmp_path))
def test_plot(self, dataset: GID15) -> None:


@ -14,6 +14,7 @@ from rasterio.crs import CRS
from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
GlobBiomass,
IntersectionDataset,
UnionDataset,
@ -50,7 +51,7 @@ class TestGlobBiomass:
GlobBiomass(dataset.paths)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
GlobBiomass(str(tmp_path), checksum=True)
def test_corrupted(self, tmp_path: Path) -> None:


@ -16,7 +16,7 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import IDTReeS
from torchgeo.datasets import DatasetNotFoundError, IDTReeS
pytest.importorskip("laspy", minversion="2")
@ -91,10 +91,7 @@ class TestIDTReeS:
IDTReeS(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
with pytest.raises(RuntimeError, match=err):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
IDTReeS(str(tmp_path))
def test_not_extracted(self, tmp_path: Path) -> None:


@ -8,6 +8,7 @@ import pytest
from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
INaturalist,
IntersectionDataset,
UnionDataset,
@ -36,7 +37,7 @@ class TestINaturalist:
assert isinstance(ds, UnionDataset)
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
INaturalist(str(tmp_path))
def test_invalid_query(self, dataset: INaturalist) -> None:


@ -11,7 +11,7 @@ import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchgeo.datasets import InriaAerialImageLabeling
from torchgeo.datasets import DatasetNotFoundError, InriaAerialImageLabeling
class TestInriaAerialImageLabeling:
@ -49,7 +49,7 @@ class TestInriaAerialImageLabeling:
InriaAerialImageLabeling(root=dataset.root)
def test_not_downloaded(self, tmp_path: str) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
InriaAerialImageLabeling(str(tmp_path))
def test_dataset_checksum(self, dataset: InriaAerialImageLabeling) -> None:


@ -14,7 +14,13 @@ from pytest import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import BoundingBox, IntersectionDataset, L7Irish, UnionDataset
from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
L7Irish,
UnionDataset,
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -68,7 +74,7 @@ class TestL7Irish:
L7Irish(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
L7Irish(str(tmp_path))
def test_plot_prediction(self, dataset: L7Irish) -> None:


@ -14,7 +14,13 @@ from pytest import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import BoundingBox, IntersectionDataset, L8Biome, UnionDataset
from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
L8Biome,
UnionDataset,
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -68,7 +74,7 @@ class TestL8Biome:
L8Biome(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
L8Biome(str(tmp_path))
def test_plot_prediction(self, dataset: L8Biome) -> None:


@ -14,7 +14,12 @@ from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import BoundingBox, LandCoverAI, LandCoverAIGeo
from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
LandCoverAI,
LandCoverAIGeo,
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -49,7 +54,7 @@ class TestLandCoverAIGeo:
LandCoverAIGeo(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
LandCoverAIGeo(str(tmp_path))
def test_out_of_bounds_query(self, dataset: LandCoverAIGeo) -> None:
@ -115,7 +120,7 @@ class TestLandCoverAI:
LandCoverAI(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
LandCoverAI(str(tmp_path))
def test_invalid_split(self) -> None:


@ -12,7 +12,13 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from rasterio.crs import CRS
from torchgeo.datasets import BoundingBox, IntersectionDataset, Landsat8, UnionDataset
from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
Landsat8,
UnionDataset,
)
class TestLandsat8:
@ -60,7 +66,7 @@ class TestLandsat8:
ds.plot(x)
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No Landsat8 data was found in "):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
Landsat8(str(tmp_path))
def test_invalid_query(self, dataset: Landsat8) -> None:


@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import LEVIRCDPlus
from torchgeo.datasets import DatasetNotFoundError, LEVIRCDPlus
def download_url(url: str, root: str, *args: str) -> None:
@ -55,7 +55,7 @@ class TestLEVIRCDPlus:
LEVIRCDPlus(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
LEVIRCDPlus(str(tmp_path))
def test_plot(self, dataset: LEVIRCDPlus) -> None:


@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import LoveDA
from torchgeo.datasets import DatasetNotFoundError, LoveDA
def download_url(url: str, root: str, *args: str) -> None:
@ -83,9 +83,7 @@ class TestLoveDA:
LoveDA(scene=["garden"])
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(
RuntimeError, match="Dataset not found at root directory or corrupted."
):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
LoveDA(str(tmp_path))
def test_plot(self, dataset: LoveDA) -> None:


@ -15,7 +15,7 @@ from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import MapInWild
from torchgeo.datasets import DatasetNotFoundError, MapInWild
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -97,7 +97,7 @@ class TestMapInWild:
MapInWild(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
MapInWild(root=str(tmp_path))
def test_downloaded_not_extracted(self, tmp_path: Path) -> None:


@ -11,7 +11,7 @@ import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from torchgeo.datasets import MillionAID
from torchgeo.datasets import DatasetNotFoundError, MillionAID
class TestMillionAID:
@ -38,7 +38,7 @@ class TestMillionAID:
assert len(dataset) == 2
def test_not_found(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
MillionAID(str(tmp_path))
def test_not_extracted(self, tmp_path: Path) -> None:


@ -10,7 +10,13 @@ import torch
import torch.nn as nn
from rasterio.crs import CRS
from torchgeo.datasets import NAIP, BoundingBox, IntersectionDataset, UnionDataset
from torchgeo.datasets import (
NAIP,
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
UnionDataset,
)
class TestNAIP:
@ -41,7 +47,7 @@ class TestNAIP:
plt.close()
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No NAIP data was found in "):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
NAIP(str(tmp_path))
def test_invalid_query(self, dataset: NAIP) -> None:


@ -12,7 +12,7 @@ import torch
import torch.nn as nn
from pytest import MonkeyPatch
from torchgeo.datasets import NASAMarineDebris
from torchgeo.datasets import DatasetNotFoundError, NASAMarineDebris
class Collection:
@ -90,10 +90,7 @@ class TestNASAMarineDebris:
NASAMarineDebris(root=str(tmp_path), download=True, checksum=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
with pytest.raises(RuntimeError, match=err):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
NASAMarineDebris(str(tmp_path))
def test_plot(self, dataset: NASAMarineDebris) -> None:


@ -13,7 +13,13 @@ from pytest import MonkeyPatch
from rasterio.crs import CRS
import torchgeo.datasets.utils
from torchgeo.datasets import NLCD, BoundingBox, IntersectionDataset, UnionDataset
from torchgeo.datasets import (
NLCD,
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
UnionDataset,
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -107,7 +113,7 @@ class TestNLCD:
plt.close()
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
NLCD(str(tmp_path))
def test_invalid_query(self, dataset: NLCD) -> None:


@ -16,6 +16,7 @@ from rasterio.crs import CRS
from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
OpenBuildings,
UnionDataset,
@ -52,16 +53,9 @@ class TestOpenBuildings:
assert isinstance(x["crs"], CRS)
assert isinstance(x["mask"], torch.Tensor)
def test_no_building_data_found(self, tmp_path: Path) -> None:
false_root = os.path.join(tmp_path, "empty")
os.makedirs(false_root)
shutil.copy(
os.path.join("tests", "data", "openbuildings", "tiles.geojson"), false_root
)
with pytest.raises(
RuntimeError, match="have manually downloaded the dataset as suggested "
):
OpenBuildings(false_root)
def test_not_download(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
OpenBuildings(str(tmp_path))
def test_corrupted(self, dataset: OpenBuildings, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, "000_buildings.csv.gz"), "w") as f:
@ -69,12 +63,6 @@ class TestOpenBuildings:
with pytest.raises(RuntimeError, match="Dataset found, but corrupted."):
OpenBuildings(dataset.paths, checksum=True)
def test_no_meta_data_found(self, tmp_path: Path) -> None:
false_root = os.path.join(tmp_path, "empty")
os.makedirs(false_root)
with pytest.raises(FileNotFoundError, match="Meta data file"):
OpenBuildings(false_root)
def test_nothing_in_index(self, dataset: OpenBuildings, tmp_path: Path) -> None:
# change meta data to another 'title_url' so that there is no match found
with open(os.path.join(tmp_path, "tiles.geojson")) as f:
@ -84,7 +72,7 @@ class TestOpenBuildings:
with open(os.path.join(tmp_path, "tiles.geojson"), "w") as f:
json.dump(content, f)
with pytest.raises(FileNotFoundError, match="data was found in"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
OpenBuildings(dataset.paths)
def test_getitem(self, dataset: OpenBuildings) -> None:


@ -15,7 +15,7 @@ from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import OSCD
from torchgeo.datasets import OSCD, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -107,7 +107,7 @@ class TestOSCD:
OSCD(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
OSCD(str(tmp_path))
def test_plot(self, dataset: OSCD) -> None:


@ -14,7 +14,7 @@ from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import PASTIS
from torchgeo.datasets import PASTIS, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -80,7 +80,7 @@ class TestPASTIS:
PASTIS(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
PASTIS(str(tmp_path))
def test_corrupted(self, tmp_path: Path) -> None:


@ -12,7 +12,7 @@ import torch.nn as nn
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import PatternNet
from torchgeo.datasets import DatasetNotFoundError, PatternNet
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -52,10 +52,7 @@ class TestPatternNet:
PatternNet(root=str(tmp_path), download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
with pytest.raises(RuntimeError, match=err):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
PatternNet(str(tmp_path))
def test_plot(self, dataset: PatternNet) -> None:


@ -12,7 +12,7 @@ import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchgeo.datasets import Potsdam2D
from torchgeo.datasets import DatasetNotFoundError, Potsdam2D
class TestPotsdam2D:
@ -60,7 +60,7 @@ class TestPotsdam2D:
Potsdam2D(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
Potsdam2D(str(tmp_path))
def test_plot(self, dataset: Potsdam2D) -> None:


@ -12,7 +12,7 @@ import torch.nn as nn
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ReforesTree
from torchgeo.datasets import DatasetNotFoundError, ReforesTree
def download_url(url: str, root: str, *args: str) -> None:
@ -66,7 +66,7 @@ class TestReforesTree:
ReforesTree(root=str(tmp_path), checksum=True)
def test_not_found(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
ReforesTree(str(tmp_path))
def test_plot(self, dataset: ReforesTree) -> None:


@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import RESISC45
from torchgeo.datasets import RESISC45, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -78,10 +78,7 @@ class TestRESISC45:
RESISC45(root=str(tmp_path), download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
with pytest.raises(RuntimeError, match=err):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
RESISC45(str(tmp_path))
def test_plot(self, dataset: RESISC45) -> None:


@ -14,7 +14,7 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import RwandaFieldBoundary
from torchgeo.datasets import DatasetNotFoundError, RwandaFieldBoundary
class Collection:
@ -87,7 +87,7 @@ class TestRwandaFieldBoundary:
RwandaFieldBoundary(root=dataset.root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
RwandaFieldBoundary(str(tmp_path))
def test_corrupted(self, tmp_path: Path) -> None:


@ -15,7 +15,7 @@ from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import SeasoNet
from torchgeo.datasets import DatasetNotFoundError, SeasoNet
def download_url(url: str, root: str, md5: str, *args: str, **kwargs: str) -> None:
@ -147,7 +147,7 @@ class TestSeasoNet:
SeasoNet(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="not found in"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SeasoNet(str(tmp_path), download=False)
def test_out_of_bounds(self, dataset: SeasoNet) -> None:


@ -15,7 +15,7 @@ from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import SeasonalContrastS2
from torchgeo.datasets import DatasetNotFoundError, SeasonalContrastS2
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -98,7 +98,7 @@ class TestSeasonalContrastS2:
SeasonalContrastS2(bands=["A1steaksauce"])
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SeasonalContrastS2(str(tmp_path))
def test_plot(self, dataset: SeasonalContrastS2) -> None:


@ -12,7 +12,7 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import SEN12MS
from torchgeo.datasets import SEN12MS, DatasetNotFoundError
class TestSEN12MS:
@ -65,10 +65,10 @@ class TestSEN12MS:
SEN12MS(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SEN12MS(str(tmp_path), checksum=True)
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SEN12MS(str(tmp_path), checksum=False)
def test_check_integrity_light(self) -> None:


@ -13,6 +13,7 @@ from rasterio.crs import CRS
from torchgeo.datasets import (
BoundingBox,
DatasetNotFoundError,
IntersectionDataset,
Sentinel1,
Sentinel2,
@ -64,7 +65,7 @@ class TestSentinel1:
plt.close()
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No Sentinel1 data was found in "):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
Sentinel1(str(tmp_path))
def test_empty_bands(self) -> None:
@ -123,7 +124,7 @@ class TestSentinel2:
assert isinstance(ds, UnionDataset)
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(FileNotFoundError, match="No Sentinel2 data was found in "):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
Sentinel2(str(tmp_path))
def test_plot(self, dataset: Sentinel2) -> None:


@ -16,7 +16,7 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import SKIPPD
from torchgeo.datasets import SKIPPD, DatasetNotFoundError
pytest.importorskip("h5py", minversion="3")
@ -105,7 +105,7 @@ class TestSKIPPD:
SKIPPD(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SKIPPD(str(tmp_path))
def test_plot(self, dataset: SKIPPD) -> None:


@ -13,7 +13,7 @@ import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchgeo.datasets import So2Sat
from torchgeo.datasets import DatasetNotFoundError, So2Sat
pytest.importorskip("h5py", minversion="3")
@ -70,7 +70,7 @@ class TestSo2Sat:
So2Sat(bands=("OK", "BK"))
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
So2Sat(str(tmp_path))
def test_plot(self, dataset: So2Sat) -> None:


@ -14,6 +14,7 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchgeo.datasets import (
DatasetNotFoundError,
SpaceNet1,
SpaceNet2,
SpaceNet3,
@ -91,7 +92,7 @@ class TestSpaceNet1:
SpaceNet1(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SpaceNet1(str(tmp_path))
def test_plot(self, dataset: SpaceNet1) -> None:
@ -147,7 +148,7 @@ class TestSpaceNet2:
SpaceNet2(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SpaceNet2(str(tmp_path))
def test_collection_checksum(self, dataset: SpaceNet2) -> None:
@ -207,7 +208,7 @@ class TestSpaceNet3:
SpaceNet3(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SpaceNet3(str(tmp_path))
def test_collection_checksum(self, dataset: SpaceNet3) -> None:
@ -271,7 +272,7 @@ class TestSpaceNet4:
SpaceNet4(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SpaceNet4(str(tmp_path))
def test_collection_checksum(self, dataset: SpaceNet4) -> None:
@ -333,7 +334,7 @@ class TestSpaceNet5:
SpaceNet5(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SpaceNet5(str(tmp_path))
def test_collection_checksum(self, dataset: SpaceNet5) -> None:
@ -427,7 +428,7 @@ class TestSpaceNet7:
SpaceNet7(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SpaceNet7(str(tmp_path))
def test_collection_checksum(self, dataset: SpaceNet4) -> None:


@ -15,7 +15,7 @@ from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo
from torchgeo.datasets import SSL4EOL, SSL4EOS12
from torchgeo.datasets import SSL4EOL, SSL4EOS12, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -94,7 +94,7 @@ class TestSSL4EOL:
SSL4EOL(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SSL4EOL(str(tmp_path))
def test_invalid_split(self) -> None:
@ -155,7 +155,7 @@ class TestSSL4EOS12:
SSL4EOS12(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SSL4EOS12(str(tmp_path))
def test_plot(self, dataset: SSL4EOS12) -> None:


@ -16,7 +16,13 @@ from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import CDL, NLCD, RasterDataset, SSL4EOLBenchmark
from torchgeo.datasets import (
CDL,
NLCD,
DatasetNotFoundError,
RasterDataset,
SSL4EOLBenchmark,
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -137,7 +143,7 @@ class TestSSL4EOLBenchmark:
SSL4EOLBenchmark(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SSL4EOLBenchmark(str(tmp_path))
def test_plot(self, dataset: SSL4EOLBenchmark) -> None:


@ -13,7 +13,7 @@ from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import SustainBenchCropYield
from torchgeo.datasets import DatasetNotFoundError, SustainBenchCropYield
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -71,7 +71,7 @@ class TestSustainBenchCropYield:
SustainBenchCropYield(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
SustainBenchCropYield(str(tmp_path))
def test_plot(self, dataset: SustainBenchCropYield) -> None:


@ -14,7 +14,7 @@ from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import UCMerced
from torchgeo.datasets import DatasetNotFoundError, UCMerced
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -81,10 +81,7 @@ class TestUCMerced:
UCMerced(root=str(tmp_path), download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
with pytest.raises(RuntimeError, match=err):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
UCMerced(str(tmp_path))
def test_plot(self, dataset: UCMerced) -> None:


@ -14,7 +14,7 @@ from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import USAVars
from torchgeo.datasets import DatasetNotFoundError, USAVars
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
@ -129,7 +129,7 @@ class TestUSAVars:
USAVars(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
USAVars(str(tmp_path))
def test_plot(self, dataset: USAVars) -> None:


@ -18,10 +18,12 @@ import pytest
import torch
from pytest import MonkeyPatch
from rasterio.crs import CRS
from torch.utils.data import Dataset
import torchgeo.datasets.utils
from torchgeo.datasets.utils import (
BoundingBox,
DatasetNotFoundError,
concat_samples,
disambiguate_timestamp,
download_and_extract_archive,
@ -36,6 +38,52 @@ from torchgeo.datasets.utils import (
)
class TestDatasetNotFoundError:
def test_none(self) -> None:
ds: Dataset[Any] = Dataset()
match = "Dataset not found."
with pytest.raises(DatasetNotFoundError, match=match):
raise DatasetNotFoundError(ds)
def test_root(self) -> None:
ds: Dataset[Any] = Dataset()
ds.root = "foo" # type: ignore[attr-defined]
match = "Dataset not found in `root='foo'` and cannot be automatically "
match += "downloaded, either specify a different `root` or manually "
match += "download the dataset."
with pytest.raises(DatasetNotFoundError, match=match):
raise DatasetNotFoundError(ds)
def test_paths(self) -> None:
ds: Dataset[Any] = Dataset()
ds.paths = "foo" # type: ignore[attr-defined]
match = "Dataset not found in `paths='foo'` and cannot be automatically "
match += "downloaded, either specify a different `paths` or manually "
match += "download the dataset."
with pytest.raises(DatasetNotFoundError, match=match):
raise DatasetNotFoundError(ds)
def test_root_download(self) -> None:
ds: Dataset[Any] = Dataset()
ds.root = "foo" # type: ignore[attr-defined]
ds.download = False # type: ignore[attr-defined]
match = "Dataset not found in `root='foo'` and `download=False`, either "
match += "specify a different `root` or use `download=True` to automatically "
match += "download the dataset."
with pytest.raises(DatasetNotFoundError, match=match):
raise DatasetNotFoundError(ds)
def test_paths_download(self) -> None:
ds: Dataset[Any] = Dataset()
ds.paths = "foo" # type: ignore[attr-defined]
ds.download = False # type: ignore[attr-defined]
match = "Dataset not found in `paths='foo'` and `download=False`, either "
match += "specify a different `paths` or use `download=True` to automatically "
match += "download the dataset."
with pytest.raises(DatasetNotFoundError, match=match):
raise DatasetNotFoundError(ds)
@pytest.fixture
def mock_missing_module(monkeypatch: MonkeyPatch) -> None:
import_orig = builtins.__import__
@ -48,7 +96,7 @@ def mock_missing_module(monkeypatch: MonkeyPatch) -> None:
monkeypatch.setattr(builtins, "__import__", mocked_import)
class Dataset:
class MLHubDataset:
def download(self, output_dir: str, **kwargs: str) -> None:
glob_path = os.path.join(
"tests", "data", "ref_african_crops_kenya_02", "*.tar.gz"
@ -66,8 +114,8 @@ class Collection:
shutil.copy(tarball, output_dir)
def fetch_dataset(dataset_id: str, **kwargs: str) -> Dataset:
return Dataset()
def fetch_dataset(dataset_id: str, **kwargs: str) -> MLHubDataset:
return MLHubDataset()
def fetch_collection(collection_id: str, **kwargs: str) -> Collection:
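The new `TestDatasetNotFoundError` cases above pin down how the message adapts to whichever `root`/`paths` and `download` attributes the dataset object exposes. Below is a sketch of an exception class that would satisfy those tests; it is reconstructed from the tests alone and is only assumed to mirror the actual implementation in `torchgeo.datasets.utils` (including the `FileNotFoundError` base class):

```python
from typing import Any


class DatasetNotFoundError(FileNotFoundError):
    """Raised when a dataset is not found on disk (sketch only)."""

    def __init__(self, dataset: Any) -> None:
        # Use whichever location attribute the dataset defines.
        if hasattr(dataset, "root"):
            var = "root"
        elif hasattr(dataset, "paths"):
            var = "paths"
        else:
            super().__init__("Dataset not found.")
            return

        value = getattr(dataset, var)
        msg = f"Dataset not found in `{var}={value!r}` and "
        if hasattr(dataset, "download"):
            msg += (
                f"`download={dataset.download}`, either specify a different "
                f"`{var}` or use `download=True` to automatically download the dataset."
            )
        else:
            msg += (
                "cannot be automatically downloaded, either specify a "
                f"different `{var}` or manually download the dataset."
            )
        super().__init__(msg)
```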


@ -12,7 +12,7 @@ import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchgeo.datasets import Vaihingen2D
from torchgeo.datasets import DatasetNotFoundError, Vaihingen2D
class TestVaihingen2D:
@ -69,7 +69,7 @@ class TestVaihingen2D:
Vaihingen2D(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
Vaihingen2D(str(tmp_path))
def test_plot(self, dataset: Vaihingen2D) -> None:


@ -16,7 +16,7 @@ from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
import torchgeo.datasets.utils
from torchgeo.datasets import VHR10
from torchgeo.datasets import VHR10, DatasetNotFoundError
pytest.importorskip("pycocotools")
@ -90,7 +90,7 @@ class TestVHR10:
VHR10(split="train")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found or corrupted."):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
VHR10(str(tmp_path))
def test_mock_missing_module(


@ -10,7 +10,7 @@ import torch
import torch.nn as nn
from pytest import MonkeyPatch
from torchgeo.datasets import WesternUSALiveFuelMoisture
from torchgeo.datasets import DatasetNotFoundError, WesternUSALiveFuelMoisture
class Collection:
@ -65,7 +65,7 @@ class TestWesternUSALiveFuelMoisture:
WesternUSALiveFuelMoisture(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
WesternUSALiveFuelMoisture(str(tmp_path))
def test_invalid_features(self, dataset: WesternUSALiveFuelMoisture) -> None:


@ -12,7 +12,7 @@ import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchgeo.datasets import XView2
from torchgeo.datasets import DatasetNotFoundError, XView2
class TestXView2:
@ -80,7 +80,7 @@ class TestXView2:
XView2(split="foo")
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(RuntimeError, match="Dataset not found in `root` directory"):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
XView2(str(tmp_path))
def test_plot(self, dataset: XView2) -> None:


@ -14,7 +14,7 @@ import torch.nn as nn
from pytest import MonkeyPatch
import torchgeo.datasets.utils
from torchgeo.datasets import ZueriCrop
from torchgeo.datasets import DatasetNotFoundError, ZueriCrop
pytest.importorskip("h5py", minversion="3")
@ -79,10 +79,7 @@ class TestZueriCrop:
ZueriCrop(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
err = "Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
with pytest.raises(RuntimeError, match=err):
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
ZueriCrop(str(tmp_path))
def test_mock_missing_module(


@ -116,6 +116,7 @@ from .ucmerced import UCMerced
from .usavars import USAVars
from .utils import (
BoundingBox,
DatasetNotFoundError,
concat_samples,
merge_samples,
stack_samples,
@ -253,4 +254,6 @@ __all__ = (
"random_grid_cell_assignment",
"roi_split",
"time_series_split",
# Errors
"DatasetNotFoundError",
)
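With `DatasetNotFoundError` exported from `torchgeo.datasets`, downstream code can catch a single exception type instead of parsing `RuntimeError` strings. A short, hedged usage example (the `root` path is made up; `EuroSAT` again serves only as an example dataset from this diff):

```python
from torchgeo.datasets import DatasetNotFoundError, EuroSAT

try:
    ds = EuroSAT(root="data/eurosat", download=False)  # hypothetical local path
except DatasetNotFoundError as err:
    # The message says whether to fix `root` or pass `download=True`.
    print(f"Missing data: {err}")
```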


@ -15,7 +15,7 @@ from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
from .utils import download_and_extract_archive
from .utils import DatasetNotFoundError, download_and_extract_archive
class ADVANCE(NonGeoDataset):
@ -101,8 +101,7 @@ class ADVANCE(NonGeoDataset):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self.root = root
self.transforms = transforms
@ -112,10 +111,7 @@ class ADVANCE(NonGeoDataset):
self._download()
if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted. "
+ "You can use download=True to download it"
)
raise DatasetNotFoundError(self)
self.files = self._load_files(self.root)
self.classes = sorted({f["cls"] for f in self.files})
@ -218,11 +214,7 @@ class ADVANCE(NonGeoDataset):
return True
def _download(self) -> None:
"""Download the dataset and extract it.
Raises:
AssertionError: if the checksum of split.py does not match
"""
"""Download the dataset and extract it."""
if self._check_integrity():
print("Files already downloaded and verified")
return
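On the library side, each dataset constructor now funnels its missing-data branch into `raise DatasetNotFoundError(self)`, as the `ADVANCE` diff above shows. A simplified sketch of that constructor-time pattern with a made-up dataset class (real datasets also handle extraction, checksums, and splits):

```python
from torchgeo.datasets import DatasetNotFoundError


class MyDataset:
    """Illustrative only: shows the shared missing-data check, not a real dataset."""

    def __init__(self, root: str = "data", download: bool = False) -> None:
        self.root = root
        self.download = download
        if download:
            self._download()
        if not self._check_integrity():
            # One shared exception replaces the per-dataset RuntimeError strings;
            # its message is derived from self.root and self.download.
            raise DatasetNotFoundError(self)

    def _download(self) -> None:
        ...  # stub for the sketch

    def _check_integrity(self) -> bool:
        return False  # stub for the sketch
```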


@ -13,7 +13,7 @@ from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
from .utils import download_url
from .utils import DatasetNotFoundError, download_url
class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
@ -77,7 +77,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
cache: if True, cache file handle to speed up repeated sampling
Raises:
FileNotFoundError: if no files are found in ``paths``
DatasetNotFoundError: If dataset is not found and *download* is False.
.. versionchanged:: 0.5
*root* was renamed to *paths*.
@ -90,22 +90,14 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if dataset is missing
"""
"""Verify the integrity of the dataset."""
# Check if the extracted files already exist
if self.files:
return
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `paths={self.paths!r}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
raise DatasetNotFoundError(self)
# Download the dataset
self._download()


@ -10,6 +10,7 @@ from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
from .utils import DatasetNotFoundError
class AsterGDEM(RasterDataset):
@ -65,8 +66,7 @@ class AsterGDEM(RasterDataset):
cache: if True, cache file handle to speed up repeated sampling
Raises:
FileNotFoundError: if no files are found in ``paths``
RuntimeError: if dataset is missing
DatasetNotFoundError: If dataset is not found.
.. versionchanged:: 0.5
*root* was renamed to *paths*.
@ -78,20 +78,12 @@ class AsterGDEM(RasterDataset):
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if dataset is missing
"""
"""Verify the integrity of the dataset."""
# Check if the extracted files already exists
if self.files:
return
raise RuntimeError(
f"Dataset not found in `paths={self.paths!r}` "
"either specify a different `root` directory or make sure you "
"have manually downloaded dataset tiles as suggested in the documentation."
)
raise DatasetNotFoundError(self)
def plot(
self,


@ -18,7 +18,12 @@ from rasterio.crs import CRS
from torch import Tensor
from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
from .utils import (
DatasetNotFoundError,
check_integrity,
download_radiant_mlhub_collection,
extract_archive,
)
# TODO: read geospatial information from stac.json files
@ -198,7 +203,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
verbose: if True, print messages when new tiles are loaded
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self._validate_bands(bands)
@ -214,10 +219,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
self._download(api_key)
if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted. "
+ "You can use download=True to download it"
)
raise DatasetNotFoundError(self)
# Calculate the indices that we will use over all tiles
self.chips_metadata = []


@ -17,7 +17,12 @@ from rasterio.enums import Resampling
from torch import Tensor
from .geo import NonGeoDataset
from .utils import download_url, extract_archive, sort_sentinel2_bands
from .utils import (
DatasetNotFoundError,
download_url,
extract_archive,
sort_sentinel2_bands,
)
class BigEarthNet(NonGeoDataset):
@ -285,6 +290,9 @@ class BigEarthNet(NonGeoDataset):
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.splits_metadata
assert bands in ["s1", "s2", "all"]
@ -434,11 +442,7 @@ class BigEarthNet(NonGeoDataset):
return target
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
"""Verify the integrity of the dataset."""
keys = ["s1", "s2"] if self.bands == "all" else [self.bands]
urls = [self.metadata[k]["url"] for k in keys]
md5s = [self.metadata[k]["md5"] for k in keys]
@ -478,11 +482,7 @@ class BigEarthNet(NonGeoDataset):
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
"Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
raise DatasetNotFoundError(self)
# Download and extract the dataset
for url, filename, md5 in zip(urls, filenames, md5s):


@ -16,7 +16,7 @@ from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
from .utils import percentile_normalization
from .utils import DatasetNotFoundError, percentile_normalization
class BioMassters(NonGeoDataset):
@ -75,8 +75,9 @@ class BioMassters(NonGeoDataset):
as_time_series: whether or not to return all available
time-steps or just a single one for a given target location
RuntimeError:
Raises:
AssertionError: if ``split`` or ``sensors`` is invalid
DatasetNotFoundError: If dataset is not found.
"""
self.root = root
@ -212,7 +213,7 @@ class BioMassters(NonGeoDataset):
if all(exists):
return
raise RuntimeError(f"Dataset not found in `root={self.root}`.")
raise DatasetNotFoundError(self)
def plot(
self,


@ -12,7 +12,7 @@ from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import VectorDataset
from .utils import check_integrity, download_and_extract_archive
from .utils import DatasetNotFoundError, check_integrity, download_and_extract_archive
class CanadianBuildingFootprints(VectorDataset):
@ -81,9 +81,7 @@ class CanadianBuildingFootprints(VectorDataset):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
FileNotFoundError: if no files are found in ``root``
RuntimeError: if ``download=False`` and data is not found, or
``checksum=True`` and checksums don't match
DatasetNotFoundError: If dataset is not found and *download* is False.
.. versionchanged:: 0.5
*root* was renamed to *paths*.
@ -95,10 +93,7 @@ class CanadianBuildingFootprints(VectorDataset):
self._download()
if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted. "
+ "You can use download=True to download it"
)
raise DatasetNotFoundError(self)
super().__init__(paths, crs, res, transforms)


@ -13,7 +13,7 @@ from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
from .utils import BoundingBox, download_url, extract_archive
from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
class CDL(RasterDataset):
@ -234,8 +234,7 @@ class CDL(RasterDataset):
Raises:
AssertionError: if ``years`` or ``classes`` are invalid
FileNotFoundError: if no files are found in ``paths``
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
DatasetNotFoundError: If dataset is not found and *download* is False.
.. versionadded:: 0.5
The *years* and *classes* parameters.
@ -286,11 +285,7 @@ class CDL(RasterDataset):
return sample
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
"""Verify the integrity of the dataset."""
# Check if the extracted files already exist
if self.files:
return
@ -313,11 +308,7 @@ class CDL(RasterDataset):
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `paths={self.paths!r}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
raise DatasetNotFoundError(self)
# Download the dataset
self._download()


@ -24,7 +24,7 @@ from rasterio.crs import CRS
from torch import Tensor
from .geo import GeoDataset, RasterDataset
from .utils import BoundingBox, download_url, extract_archive
from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
class Chesapeake(RasterDataset, abc.ABC):
@ -112,8 +112,7 @@ class Chesapeake(RasterDataset, abc.ABC):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
FileNotFoundError: if no files are found in ``paths``
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
DatasetNotFoundError: If dataset is not found and *download* is False.
.. versionchanged:: 0.5
*root* was renamed to *paths*.
@ -138,11 +137,7 @@ class Chesapeake(RasterDataset, abc.ABC):
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
"""Verify the integrity of the dataset."""
# Check if the extracted file already exists
if self.files:
return
@ -155,11 +150,7 @@ class Chesapeake(RasterDataset, abc.ABC):
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `paths={self.paths!r}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
raise DatasetNotFoundError(self)
# Download the dataset
self._download()
@ -562,9 +553,8 @@ class ChesapeakeCVPR(GeoDataset):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
FileNotFoundError: if no files are found in ``root``
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
AssertionError: if ``splits`` or ``layers`` are not valid
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
for split in splits:
assert split in self.splits
@ -694,11 +684,7 @@ class ChesapeakeCVPR(GeoDataset):
return sample
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
"""Verify the integrity of the dataset."""
def exists(filename: str) -> bool:
return os.path.exists(os.path.join(self.root, filename))
@ -719,11 +705,7 @@ class ChesapeakeCVPR(GeoDataset):
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
raise DatasetNotFoundError(self)
# Download the dataset
self._download()


@ -16,7 +16,12 @@ from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
from .utils import (
DatasetNotFoundError,
check_integrity,
download_radiant_mlhub_collection,
extract_archive,
)
# TODO: read geospatial information from stac.json files
@ -123,7 +128,7 @@ class CloudCoverDetection(NonGeoDataset):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self.root = root
self.split = split
@ -137,10 +142,7 @@ class CloudCoverDetection(NonGeoDataset):
self._download(api_key)
if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted. "
+ "You can use download=True to download it"
)
raise DatasetNotFoundError(self)
self.chip_paths = self._load_collections()
@ -331,9 +333,6 @@ class CloudCoverDetection(NonGeoDataset):
Args:
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
Raises:
RuntimeError: if download doesn't work correctly or checksums don't match
"""
if self._check_integrity():
print("Files already downloaded and verified")


@ -11,7 +11,7 @@ from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
from .utils import check_integrity, extract_archive
from .utils import DatasetNotFoundError, check_integrity, extract_archive
class CMSGlobalMangroveCanopy(RasterDataset):
@ -192,9 +192,8 @@ class CMSGlobalMangroveCanopy(RasterDataset):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
FileNotFoundError: if no files are found in ``paths``
RuntimeError: if dataset is missing or checksum fails
AssertionError: if country or measurement arg are not str or invalid
DatasetNotFoundError: If dataset is not found.
.. versionchanged:: 0.5
*root* was renamed to *paths*.
@ -225,11 +224,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if dataset is missing or checksum fails
"""
"""Verify the integrity of the dataset."""
# Check if the extracted files already exist
if self.files:
return
@ -243,11 +238,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
self._extract()
return
raise RuntimeError(
f"Dataset not found in `paths={self.paths!r}` "
"either specify a different `root` directory or make sure you "
"have manually downloaded the dataset as instructed in the documentation."
)
raise DatasetNotFoundError(self)
def _extract(self) -> None:
"""Extract the dataset."""


@ -16,7 +16,7 @@ from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
from .utils import check_integrity, download_and_extract_archive
from .utils import DatasetNotFoundError, check_integrity, download_and_extract_archive
class COWC(NonGeoDataset, abc.ABC):
@ -81,8 +81,7 @@ class COWC(NonGeoDataset, abc.ABC):
Raises:
AssertionError: if ``split`` argument is invalid
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in ["train", "test"]
@ -95,10 +94,7 @@ class COWC(NonGeoDataset, abc.ABC):
self._download()
if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted. "
+ "You can use download=True to download it"
)
raise DatasetNotFoundError(self)
self.images = []
self.targets = []


@ -16,7 +16,12 @@ from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
from .utils import (
DatasetNotFoundError,
check_integrity,
download_radiant_mlhub_collection,
extract_archive,
)
# TODO: read geospatial information from stac.json files
@ -141,7 +146,7 @@ class CV4AKenyaCropType(NonGeoDataset):
verbose: if True, print messages when new tiles are loaded
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self._validate_bands(bands)
@ -157,10 +162,7 @@ class CV4AKenyaCropType(NonGeoDataset):
self._download(api_key)
if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted. "
+ "You can use download=True to download it"
)
raise DatasetNotFoundError(self)
# Calculate the indices that we will use over all tiles
self.chips_metadata = []
@ -390,9 +392,6 @@ class CV4AKenyaCropType(NonGeoDataset):
Args:
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
Raises:
RuntimeError: if download doesn't work correctly or checksums don't match
"""
if self._check_integrity():
print("Files already downloaded and verified")


@ -16,7 +16,12 @@ from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
from .utils import (
DatasetNotFoundError,
check_integrity,
download_radiant_mlhub_collection,
extract_archive,
)
class TropicalCyclone(NonGeoDataset):
@ -86,7 +91,7 @@ class TropicalCyclone(NonGeoDataset):
Raises:
AssertionError: if ``split`` argument is invalid
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.md5s
@ -99,10 +104,7 @@ class TropicalCyclone(NonGeoDataset):
self._download(api_key)
if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted. "
+ "You can use download=True to download it"
)
raise DatasetNotFoundError(self)
output_dir = "_".join([self.collection_id, split, "source"])
filename = os.path.join(root, output_dir, "collection.json")
@ -206,9 +208,6 @@ class TropicalCyclone(NonGeoDataset):
Args:
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
Raises:
RuntimeError: if download doesn't work correctly or checksums don't match
"""
if self._check_integrity():
print("Files already downloaded and verified")


@ -15,6 +15,7 @@ from torch import Tensor
from .geo import NonGeoDataset
from .utils import (
DatasetNotFoundError,
check_integrity,
draw_semantic_segmentation_masks,
extract_archive,
@ -102,6 +103,9 @@ class DeepGlobeLandCover(NonGeoDataset):
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
DatasetNotFoundError: If dataset is not found.
"""
assert split in self.splits
self.root = root
@ -195,11 +199,7 @@ class DeepGlobeLandCover(NonGeoDataset):
return tensor
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if checksum fails or the dataset is not downloaded
"""
"""Verify the integrity of the dataset."""
# Check if the files already exist
if os.path.exists(os.path.join(self.root, self.data_root)):
return
@ -213,11 +213,7 @@ class DeepGlobeLandCover(NonGeoDataset):
extract_archive(filepath)
return
# Check if the user requested to download the dataset
raise RuntimeError(
"Dataset not found in `root`, either specify a different"
+ " `root` directory or manually download the dataset to this directory."
)
raise DatasetNotFoundError(self)
def plot(
self,


@ -18,7 +18,12 @@ from rasterio.enums import Resampling
from torch import Tensor
from .geo import NonGeoDataset
from .utils import check_integrity, extract_archive, percentile_normalization
from .utils import (
DatasetNotFoundError,
check_integrity,
extract_archive,
percentile_normalization,
)
class DFC2022(NonGeoDataset):
@ -153,6 +158,7 @@ class DFC2022(NonGeoDataset):
Raises:
AssertionError: if ``split`` is invalid
DatasetNotFoundError: If dataset is not found.
"""
assert split in self.metadata
self.root = root
@ -258,11 +264,7 @@ class DFC2022(NonGeoDataset):
return tensor
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if checksum fails or the dataset is not downloaded
"""
"""Verify the integrity of the dataset."""
# Check if the files already exist
exists = []
for split_info in self.metadata.values():
@ -288,11 +290,7 @@ class DFC2022(NonGeoDataset):
if all(exists):
return
# Check if the user requested to download the dataset
raise RuntimeError(
"Dataset not found in `root` directory, either specify a different"
+ " `root` directory or manually download the dataset to this directory."
)
raise DatasetNotFoundError(self)
def plot(
self,


@ -12,7 +12,7 @@ import pandas as pd
from rasterio.crs import CRS
from .geo import GeoDataset
from .utils import BoundingBox, disambiguate_timestamp
from .utils import BoundingBox, DatasetNotFoundError, disambiguate_timestamp
class EDDMapS(GeoDataset):
@ -48,7 +48,7 @@ class EDDMapS(GeoDataset):
root: root directory where dataset can be found
Raises:
FileNotFoundError: if no files are found in ``root``
DatasetNotFoundError: If dataset is not found.
"""
super().__init__()
@ -56,7 +56,7 @@ class EDDMapS(GeoDataset):
filepath = os.path.join(root, "mappings.csv")
if not os.path.exists(filepath):
raise FileNotFoundError(f"Dataset not found in `root={self.root}`")
raise DatasetNotFoundError(self)
# Read CSV file
data = pd.read_csv(


@ -22,7 +22,7 @@ from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import GeoDataset
from .utils import BoundingBox, download_url, extract_archive
from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive
class EnviroAtlas(GeoDataset):
@ -278,9 +278,8 @@ class EnviroAtlas(GeoDataset):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
FileNotFoundError: if no files are found in ``root``
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
AssertionError: if ``splits`` or ``layers`` are not valid
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
for split in splits:
assert split in self.splits
@ -412,11 +411,7 @@ class EnviroAtlas(GeoDataset):
return sample
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
"""Verify the integrity of the dataset."""
def exists(filename: str) -> bool:
return os.path.exists(os.path.join(self.root, "enviroatlas_lotp", filename))
@ -432,11 +427,7 @@ class EnviroAtlas(GeoDataset):
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
raise DatasetNotFoundError(self)
# Download the dataset
self._download()


@ -13,7 +13,7 @@ from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
from .utils import download_url, extract_archive
from .utils import DatasetNotFoundError, download_url, extract_archive
class Esri2020(RasterDataset):
@ -91,8 +91,7 @@ class Esri2020(RasterDataset):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
FileNotFoundError: if no files are found in ``paths``
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
DatasetNotFoundError: If dataset is not found and *download* is False.
.. versionchanged:: 0.5
*root* was renamed to *paths*.
@ -106,11 +105,7 @@ class Esri2020(RasterDataset):
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
"""Verify the integrity of the dataset."""
# Check if the extracted file already exists
if self.files:
return
@ -124,11 +119,7 @@ class Esri2020(RasterDataset):
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `paths={self.paths!r}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
raise DatasetNotFoundError(self)
# Download the dataset
self._download()


@ -15,7 +15,7 @@ from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
from .utils import download_and_extract_archive
from .utils import DatasetNotFoundError, download_and_extract_archive
class ETCI2021(NonGeoDataset):
@ -98,8 +98,7 @@ class ETCI2021(NonGeoDataset):
Raises:
AssertionError: if ``split`` argument is invalid
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.metadata.keys()
@ -112,10 +111,7 @@ class ETCI2021(NonGeoDataset):
self._download()
if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted. "
+ "You can use download=True to download it"
)
raise DatasetNotFoundError(self)
self.files = self._load_files(self.root, self.split)
@ -243,11 +239,7 @@ class ETCI2021(NonGeoDataset):
return True
def _download(self) -> None:
"""Download the dataset and extract it.
Raises:
AssertionError: if the checksum of split.py does not match
"""
"""Download the dataset and extract it."""
if self._check_integrity():
print("Files already downloaded and verified")
return


@ -13,7 +13,7 @@ from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
from .utils import check_integrity, extract_archive
from .utils import DatasetNotFoundError, check_integrity, extract_archive
class EUDEM(RasterDataset):
@ -105,7 +105,7 @@ class EUDEM(RasterDataset):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
FileNotFoundError: if no files are found in ``paths``
DatasetNotFoundError: If dataset is not found.
.. versionchanged:: 0.5
*root* was renamed to *paths*.
@ -118,11 +118,7 @@ class EUDEM(RasterDataset):
super().__init__(paths, crs, res, transforms=transforms, cache=cache)
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if dataset is missing or checksum fails
"""
"""Verify the integrity of the dataset."""
# Check if the extracted file already exists
if self.files:
return
@ -138,11 +134,7 @@ class EUDEM(RasterDataset):
extract_archive(zipfile)
return
raise RuntimeError(
f"Dataset not found in `paths={self.paths!r}` "
"either specify a different `root` directory or make sure you "
"have manually downloaded the dataset as suggested in the documentation."
)
raise DatasetNotFoundError(self)
def plot(
self,


@ -14,7 +14,13 @@ from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoClassificationDataset
from .utils import check_integrity, download_url, extract_archive, rasterio_loader
from .utils import (
DatasetNotFoundError,
check_integrity,
download_url,
extract_archive,
rasterio_loader,
)
class EuroSAT(NonGeoClassificationDataset):
@ -116,8 +122,7 @@ class EuroSAT(NonGeoClassificationDataset):
Raises:
AssertionError: if ``split`` argument is invalid
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
DatasetNotFoundError: If dataset is not found and *download* is False.
.. versionadded:: 0.3
The *bands* parameter.
@ -180,11 +185,7 @@ class EuroSAT(NonGeoClassificationDataset):
return integrity
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
"""Verify the integrity of the dataset."""
# Check if the files already exist
filepath = os.path.join(self.root, self.base_dir)
if os.path.exists(filepath):
@ -197,11 +198,7 @@ class EuroSAT(NonGeoClassificationDataset):
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
"Dataset not found in `root` directory and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
raise DatasetNotFoundError(self)
# Download and extract the dataset
self._download()


@ -17,7 +17,7 @@ from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
from .utils import check_integrity, download_url, extract_archive
from .utils import DatasetNotFoundError, check_integrity, download_url, extract_archive
def parse_pascal_voc(path: str) -> dict[str, Any]:
@ -244,8 +244,7 @@ class FAIR1M(NonGeoDataset):
Raises:
AssertionError: if ``split`` argument is invalid
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
DatasetNotFoundError: If dataset is not found.
.. versionchanged:: 0.5
Added *split* and *download* parameters.
@ -329,11 +328,7 @@ class FAIR1M(NonGeoDataset):
return boxes, labels_tensor
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if checksum fails or the dataset is not found
"""
"""Verify the integrity of the dataset."""
# Check if the directories already exist
exists = []
for directory in self.directories[self.split]:
@ -362,18 +357,10 @@ class FAIR1M(NonGeoDataset):
self._download()
return
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
raise DatasetNotFoundError(self)
def _download(self) -> None:
"""Download the dataset and extract it.
Raises:
RuntimeError: if download doesn't work correctly or checksums don't match
"""
"""Download the dataset and extract it."""
paths = self.paths[self.split]
urls = self.urls[self.split]
md5s = self.md5s[self.split]


@ -11,7 +11,7 @@ from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoClassificationDataset
from .utils import download_url, extract_archive
from .utils import DatasetNotFoundError, download_url, extract_archive
class FireRisk(NonGeoClassificationDataset):
@ -84,7 +84,7 @@ class FireRisk(NonGeoClassificationDataset):
Raises:
AssertionError: if ``split`` argument is invalid
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.splits
self.root = root
@ -98,11 +98,7 @@ class FireRisk(NonGeoClassificationDataset):
)
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if ``download=False`` but dataset is missing or checksum fails
"""
"""Verify the integrity of the dataset."""
# Check if the files already exist
path = os.path.join(self.root, self.directory)
if os.path.exists(path):
@ -116,11 +112,7 @@ class FireRisk(NonGeoClassificationDataset):
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
f"Dataset not found in `root={self.root}` and `download=False`, "
"either specify a different `root` directory or use `download=True` "
"to automatically download the dataset."
)
raise DatasetNotFoundError(self)
# Download and extract the dataset
self._download()


@ -17,7 +17,12 @@ from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
from .utils import check_integrity, download_and_extract_archive, extract_archive
from .utils import (
DatasetNotFoundError,
check_integrity,
download_and_extract_archive,
extract_archive,
)
def parse_pascal_voc(path: str) -> dict[str, Any]:
@ -119,8 +124,7 @@ class ForestDamage(NonGeoDataset):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
self.root = root
self.transforms = transforms
@ -237,21 +241,13 @@ class ForestDamage(NonGeoDataset):
# Check if the user requested to download the dataset
if not self.download:
raise RuntimeError(
"Dataset not found in `root` directory, either specify a different"
+ " `root` directory or manually download "
+ "the dataset to this directory."
)
raise DatasetNotFoundError(self)
# else download the dataset
self._download()
def _download(self) -> None:
"""Download the dataset and extract it.
Raises:
AssertionError: if the checksum does not match
"""
"""Download the dataset and extract it."""
download_and_extract_archive(
self.url,
self.root,


@ -14,7 +14,7 @@ import pandas as pd
from rasterio.crs import CRS
from .geo import GeoDataset
from .utils import BoundingBox
from .utils import BoundingBox, DatasetNotFoundError
def _disambiguate_timestamps(
@ -86,7 +86,7 @@ class GBIF(GeoDataset):
root: root directory where dataset can be found
Raises:
FileNotFoundError: if no files are found in ``root``
DatasetNotFoundError: If dataset is not found.
"""
super().__init__()
@ -94,7 +94,7 @@ class GBIF(GeoDataset):
files = glob.glob(os.path.join(root, "**.csv"))
if not files:
raise FileNotFoundError(f"Dataset not found in `root={self.root}`")
raise DatasetNotFoundError(self)
# Read tab-delimited CSV file
data = pd.read_table(


@ -32,6 +32,7 @@ from torchvision.datasets.folder import default_loader as pil_loader
from .utils import (
BoundingBox,
DatasetNotFoundError,
concat_samples,
disambiguate_timestamp,
merge_samples,
@ -390,7 +391,7 @@ class RasterDataset(GeoDataset):
cache: if True, cache file handle to speed up repeated sampling
Raises:
FileNotFoundError: if no files are found in ``paths``
DatasetNotFoundError: If dataset is not found.
.. versionchanged:: 0.5
*root* was renamed to *paths*.
@ -438,13 +439,7 @@ class RasterDataset(GeoDataset):
i += 1
if i == 0:
msg = (
f"No {self.__class__.__name__} data was found "
f"in `paths={self.paths!r}'`"
)
if self.bands:
msg += f" with `bands={self.bands}`"
raise FileNotFoundError(msg)
raise DatasetNotFoundError(self)
if not self.separate_files:
self.band_indexes = None
@ -606,7 +601,7 @@ class VectorDataset(GeoDataset):
rasterized into the mask
Raises:
FileNotFoundError: if no files are found in ``root``
DatasetNotFoundError: If dataset is not found.
.. versionadded:: 0.4
The *label_name* parameter.
@ -642,8 +637,7 @@ class VectorDataset(GeoDataset):
i += 1
if i == 0:
msg = f"No {self.__class__.__name__} data was found in `root='{paths}'`"
raise FileNotFoundError(msg)
raise DatasetNotFoundError(self)
self._crs = crs
self._res = res
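
Because the check now lives in the RasterDataset and VectorDataset base classes, a custom subclass pointed at an empty or missing directory raises the same error as the built-in datasets. An illustrative sketch (the subclass name and glob pattern are invented for this example):

from torchgeo.datasets import DatasetNotFoundError, RasterDataset


class MyRasters(RasterDataset):
    filename_glob = "*.tif"


try:
    ds = MyRasters(paths="data/empty_dir")
except DatasetNotFoundError as err:
    # Previously this surfaced as a FileNotFoundError with a per-class message.
    print(err)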


@ -15,7 +15,7 @@ from PIL import Image
from torch import Tensor
from .geo import NonGeoDataset
from .utils import download_and_extract_archive
from .utils import DatasetNotFoundError, download_and_extract_archive
class GID15(NonGeoDataset):
@ -105,8 +105,7 @@ class GID15(NonGeoDataset):
Raises:
AssertionError: if ``split`` argument is invalid
RuntimeError: if ``download=False`` and data is not found, or checksums
don't match
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.splits
@ -119,10 +118,7 @@ class GID15(NonGeoDataset):
self._download()
if not self._check_integrity():
raise RuntimeError(
"Dataset not found or corrupted. "
+ "You can use download=True to download it"
)
raise DatasetNotFoundError(self)
self.files = self._load_files(self.root, self.split)
@ -226,11 +222,7 @@ class GID15(NonGeoDataset):
return True
def _download(self) -> None:
"""Download the dataset and extract it.
Raises:
AssertionError: if the checksum of split.py does not match
"""
"""Download the dataset and extract it."""
if self._check_integrity():
print("Files already downloaded and verified")
return


@ -14,7 +14,7 @@ from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
from .utils import BoundingBox, check_integrity, extract_archive
from .utils import BoundingBox, DatasetNotFoundError, check_integrity, extract_archive
class GlobBiomass(RasterDataset):
@ -142,9 +142,8 @@ class GlobBiomass(RasterDataset):
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
FileNotFoundError: if no files are found in ``paths``
RuntimeError: if dataset is missing or checksum fails
AssertionError: if measurement argument is invalid, or not a str
DatasetNotFoundError: If dataset is not found.
.. versionchanged:: 0.5
*root* was renamed to *paths*.
@ -204,11 +203,7 @@ class GlobBiomass(RasterDataset):
return sample
def _verify(self) -> None:
"""Verify the integrity of the dataset.
Raises:
RuntimeError: if dataset is missing or checksum fails
"""
"""Verify the integrity of the dataset."""
# Check if the extracted file already exists
if self.files:
return
@ -224,11 +219,7 @@ class GlobBiomass(RasterDataset):
extract_archive(zipfile)
return
raise RuntimeError(
f"Dataset not found in `paths={self.paths!r}` "
"either specify a different `root` directory or make sure you "
"have manually downloaded the dataset as suggested in the documentation."
)
raise DatasetNotFoundError(self)
def plot(
self,

Some files were not shown because too many files were changed in this commit.