зеркало из https://github.com/microsoft/torchgeo.git
Datasets: add support for pathlib.Path (#2173)
* Added pathlib support for ``geo.py`` * Fixed failing ruff checks. * Fixed additional ruff fromating errors * Added complete ``pathlib`` support * Additional changes * Fixed ``cyclone.py`` issues * Additional fixes * Fixed ``isinstance`` and ``Path`` inconsistency and * Fixed mypy errors in ``cdl.py`` * geo/utils: all paths are Paths * datasets: all paths are Paths * Test Paths * Type checks only work for latest torchvision * Fix tests --------- Co-authored-by: Hitesh Tolani <hitesh.ht.2003@gmail.com> Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
67eaeba79c
Коммит
44ce0073b2
|
@ -32,6 +32,7 @@ repos:
|
|||
- scikit-image>=0.22.0
|
||||
- torch>=2.3
|
||||
- torchmetrics>=0.10
|
||||
- torchvision>=0.18
|
||||
exclude: (build|data|dist|logo|logs|output)/
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: v3.1.0
|
||||
|
|
|
@ -17,7 +17,7 @@ from torchgeo.datasets import ADVANCE, DatasetNotFoundError
|
|||
pytest.importorskip('scipy', minversion='1.7.2')
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -33,7 +33,7 @@ class TestADVANCE:
|
|||
md5s = ['43acacecebecd17a82bc2c1e719fd7e4', '039b7baa47879a8a4e32b9dd8287f6ad']
|
||||
monkeypatch.setattr(ADVANCE, 'urls', urls)
|
||||
monkeypatch.setattr(ADVANCE, 'md5s', md5s)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return ADVANCE(root, transforms, download=True, checksum=True)
|
||||
|
||||
|
@ -57,7 +57,7 @@ class TestADVANCE:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
ADVANCE(str(tmp_path))
|
||||
ADVANCE(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: ADVANCE) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -21,7 +21,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -42,7 +42,7 @@ class TestAbovegroundLiveWoodyBiomassDensity:
|
|||
)
|
||||
monkeypatch.setattr(AbovegroundLiveWoodyBiomassDensity, 'url', url)
|
||||
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
return AbovegroundLiveWoodyBiomassDensity(
|
||||
root, transforms=transforms, download=True
|
||||
)
|
||||
|
@ -58,7 +58,7 @@ class TestAbovegroundLiveWoodyBiomassDensity:
|
|||
|
||||
def test_no_dataset(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
AbovegroundLiveWoodyBiomassDensity(str(tmp_path))
|
||||
AbovegroundLiveWoodyBiomassDensity(tmp_path)
|
||||
|
||||
def test_already_downloaded(
|
||||
self, dataset: AbovegroundLiveWoodyBiomassDensity
|
||||
|
|
|
@ -50,7 +50,7 @@ class TestAgriFieldNet:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
AgriFieldNet(str(tmp_path))
|
||||
AgriFieldNet(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: AgriFieldNet) -> None:
|
||||
x = dataset[dataset.bounds]
|
||||
|
|
|
@ -52,7 +52,7 @@ class TestAirphen:
|
|||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
Airphen(str(tmp_path))
|
||||
Airphen(tmp_path)
|
||||
|
||||
def test_invalid_query(self, dataset: Airphen) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
|
|
@ -25,7 +25,7 @@ class TestAsterGDEM:
|
|||
def dataset(self, tmp_path: Path) -> AsterGDEM:
|
||||
zipfile = os.path.join('tests', 'data', 'astergdem', 'astergdem.zip')
|
||||
shutil.unpack_archive(zipfile, tmp_path, 'zip')
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return AsterGDEM(root, transforms=transforms)
|
||||
|
||||
|
@ -33,7 +33,7 @@ class TestAsterGDEM:
|
|||
shutil.rmtree(tmp_path)
|
||||
os.makedirs(tmp_path)
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
AsterGDEM(str(tmp_path))
|
||||
AsterGDEM(tmp_path)
|
||||
|
||||
def test_getitem(self, dataset: AsterGDEM) -> None:
|
||||
x = dataset[dataset.bounds]
|
||||
|
|
|
@ -29,7 +29,7 @@ class TestBeninSmallHolderCashews:
|
|||
monkeypatch.setattr(BeninSmallHolderCashews, 'dates', ('20191105',))
|
||||
monkeypatch.setattr(BeninSmallHolderCashews, 'tile_height', 2)
|
||||
monkeypatch.setattr(BeninSmallHolderCashews, 'tile_width', 2)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return BeninSmallHolderCashews(root, transforms=transforms, download=True)
|
||||
|
||||
|
@ -54,7 +54,7 @@ class TestBeninSmallHolderCashews:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
BeninSmallHolderCashews(str(tmp_path))
|
||||
BeninSmallHolderCashews(tmp_path)
|
||||
|
||||
def test_invalid_bands(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
|
|
@ -16,7 +16,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import BigEarthNet, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -63,7 +63,7 @@ class TestBigEarthNet:
|
|||
monkeypatch.setattr(BigEarthNet, 'metadata', metadata)
|
||||
monkeypatch.setattr(BigEarthNet, 'splits_metadata', splits_metadata)
|
||||
bands, num_classes, split = request.param
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return BigEarthNet(
|
||||
root, split, bands, num_classes, transforms, download=True, checksum=True
|
||||
|
@ -95,7 +95,7 @@ class TestBigEarthNet:
|
|||
|
||||
def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None:
|
||||
BigEarthNet(
|
||||
root=str(tmp_path),
|
||||
root=tmp_path,
|
||||
bands=dataset.bands,
|
||||
split=dataset.split,
|
||||
num_classes=dataset.num_classes,
|
||||
|
@ -112,21 +112,21 @@ class TestBigEarthNet:
|
|||
shutil.rmtree(
|
||||
os.path.join(dataset.root, dataset.metadata['s2']['directory'])
|
||||
)
|
||||
download_url(dataset.metadata['s1']['url'], root=str(tmp_path))
|
||||
download_url(dataset.metadata['s2']['url'], root=str(tmp_path))
|
||||
download_url(dataset.metadata['s1']['url'], root=tmp_path)
|
||||
download_url(dataset.metadata['s2']['url'], root=tmp_path)
|
||||
elif dataset.bands == 's1':
|
||||
shutil.rmtree(
|
||||
os.path.join(dataset.root, dataset.metadata['s1']['directory'])
|
||||
)
|
||||
download_url(dataset.metadata['s1']['url'], root=str(tmp_path))
|
||||
download_url(dataset.metadata['s1']['url'], root=tmp_path)
|
||||
else:
|
||||
shutil.rmtree(
|
||||
os.path.join(dataset.root, dataset.metadata['s2']['directory'])
|
||||
)
|
||||
download_url(dataset.metadata['s2']['url'], root=str(tmp_path))
|
||||
download_url(dataset.metadata['s2']['url'], root=tmp_path)
|
||||
|
||||
BigEarthNet(
|
||||
root=str(tmp_path),
|
||||
root=tmp_path,
|
||||
bands=dataset.bands,
|
||||
split=dataset.split,
|
||||
num_classes=dataset.num_classes,
|
||||
|
@ -135,7 +135,7 @@ class TestBigEarthNet:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
BigEarthNet(str(tmp_path))
|
||||
BigEarthNet(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: BigEarthNet) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -37,7 +37,7 @@ class TestBioMassters:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
BioMassters(str(tmp_path))
|
||||
BioMassters(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: BioMassters) -> None:
|
||||
dataset.plot(dataset[0], suptitle='Test')
|
||||
|
|
|
@ -22,7 +22,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -41,7 +41,7 @@ class TestCanadianBuildingFootprints:
|
|||
url = os.path.join('tests', 'data', 'cbf') + os.sep
|
||||
monkeypatch.setattr(CanadianBuildingFootprints, 'url', url)
|
||||
monkeypatch.setattr(plt, 'show', lambda *args: None)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return CanadianBuildingFootprints(
|
||||
root, res=0.1, transforms=transforms, download=True, checksum=True
|
||||
|
@ -80,7 +80,7 @@ class TestCanadianBuildingFootprints:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
CanadianBuildingFootprints(str(tmp_path))
|
||||
CanadianBuildingFootprints(tmp_path)
|
||||
|
||||
def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None:
|
||||
query = BoundingBox(2, 2, 2, 2, 2, 2)
|
||||
|
|
|
@ -24,7 +24,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -41,7 +41,7 @@ class TestCDL:
|
|||
url = os.path.join('tests', 'data', 'cdl', '{}_30m_cdls.zip')
|
||||
monkeypatch.setattr(CDL, 'url', url)
|
||||
monkeypatch.setattr(plt, 'show', lambda *args: None)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return CDL(
|
||||
root,
|
||||
|
@ -87,7 +87,7 @@ class TestCDL:
|
|||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
pathname = os.path.join('tests', 'data', 'cdl', '*_30m_cdls.zip')
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
for zipfile in glob.iglob(pathname):
|
||||
shutil.copy(zipfile, root)
|
||||
CDL(root, years=[2023, 2022])
|
||||
|
@ -97,7 +97,7 @@ class TestCDL:
|
|||
AssertionError,
|
||||
match='CDL data product only exists for the following years:',
|
||||
):
|
||||
CDL(str(tmp_path), years=[1996])
|
||||
CDL(tmp_path, years=[1996])
|
||||
|
||||
def test_invalid_classes(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
@ -121,7 +121,7 @@ class TestCDL:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
CDL(str(tmp_path))
|
||||
CDL(tmp_path)
|
||||
|
||||
def test_invalid_query(self, dataset: CDL) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
|
|
@ -18,7 +18,9 @@ from torchgeo.datasets import ChaBuD, DatasetNotFoundError
|
|||
pytest.importorskip('h5py', minversion='3.6')
|
||||
|
||||
|
||||
def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(
|
||||
url: str, root: str | Path, filename: str, *args: str, **kwargs: str
|
||||
) -> None:
|
||||
shutil.copy(url, os.path.join(root, filename))
|
||||
|
||||
|
||||
|
@ -34,7 +36,7 @@ class TestChaBuD:
|
|||
monkeypatch.setattr(ChaBuD, 'url', url)
|
||||
monkeypatch.setattr(ChaBuD, 'md5', md5)
|
||||
bands, split = request.param
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return ChaBuD(
|
||||
root=root,
|
||||
|
@ -70,7 +72,7 @@ class TestChaBuD:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
ChaBuD(str(tmp_path))
|
||||
ChaBuD(tmp_path)
|
||||
|
||||
def test_invalid_bands(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
|
|
@ -26,7 +26,7 @@ from torchgeo.datasets import (
|
|||
pytest.importorskip('zipfile_deflate64')
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -41,7 +41,7 @@ class TestChesapeake13:
|
|||
)
|
||||
monkeypatch.setattr(Chesapeake13, 'url', url)
|
||||
monkeypatch.setattr(plt, 'show', lambda *args: None)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return Chesapeake13(root, transforms=transforms, download=True, checksum=True)
|
||||
|
||||
|
@ -69,13 +69,13 @@ class TestChesapeake13:
|
|||
url = os.path.join(
|
||||
'tests', 'data', 'chesapeake', 'BAYWIDE', 'Baywide_13Class_20132014.zip'
|
||||
)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
shutil.copy(url, root)
|
||||
Chesapeake13(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
Chesapeake13(str(tmp_path), checksum=True)
|
||||
Chesapeake13(tmp_path, checksum=True)
|
||||
|
||||
def test_plot(self, dataset: Chesapeake13) -> None:
|
||||
query = dataset.bounds
|
||||
|
@ -148,7 +148,7 @@ class TestChesapeakeCVPR:
|
|||
'_files',
|
||||
['de_1m_2013_extended-debuffered-test_tiles', 'spatial_index.geojson'],
|
||||
)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return ChesapeakeCVPR(
|
||||
root,
|
||||
|
@ -180,7 +180,7 @@ class TestChesapeakeCVPR:
|
|||
ChesapeakeCVPR(root=dataset.root, download=True)
|
||||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
shutil.copy(
|
||||
os.path.join(
|
||||
'tests', 'data', 'chesapeake', 'cvpr', 'cvpr_chesapeake_landcover.zip'
|
||||
|
@ -201,7 +201,7 @@ class TestChesapeakeCVPR:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
ChesapeakeCVPR(str(tmp_path), checksum=True)
|
||||
ChesapeakeCVPR(tmp_path, checksum=True)
|
||||
|
||||
def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
|
|
@ -30,7 +30,7 @@ class TestCloudCoverDetection:
|
|||
) -> CloudCoverDetection:
|
||||
url = os.path.join('tests', 'data', 'ref_cloud_cover_detection_challenge_v1')
|
||||
monkeypatch.setattr(CloudCoverDetection, 'url', url)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return CloudCoverDetection(
|
||||
|
@ -55,7 +55,7 @@ class TestCloudCoverDetection:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
CloudCoverDetection(str(tmp_path))
|
||||
CloudCoverDetection(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: CloudCoverDetection) -> None:
|
||||
sample = dataset[0]
|
||||
|
|
|
@ -20,7 +20,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -54,7 +54,7 @@ class TestCMSGlobalMangroveCanopy:
|
|||
|
||||
def test_no_dataset(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
CMSGlobalMangroveCanopy(str(tmp_path))
|
||||
CMSGlobalMangroveCanopy(tmp_path)
|
||||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
pathname = os.path.join(
|
||||
|
@ -63,7 +63,7 @@ class TestCMSGlobalMangroveCanopy:
|
|||
'cms_mangrove_canopy',
|
||||
'CMS_Global_Map_Mangrove_Canopy_1665.zip',
|
||||
)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
shutil.copy(pathname, root)
|
||||
CMSGlobalMangroveCanopy(root, country='Angola')
|
||||
|
||||
|
@ -73,7 +73,7 @@ class TestCMSGlobalMangroveCanopy:
|
|||
) as f:
|
||||
f.write('bad')
|
||||
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
|
||||
CMSGlobalMangroveCanopy(str(tmp_path), country='Angola', checksum=True)
|
||||
CMSGlobalMangroveCanopy(tmp_path, country='Angola', checksum=True)
|
||||
|
||||
def test_invalid_country(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
|
|
@ -17,7 +17,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import COWC, COWCCounting, COWCDetection, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -46,7 +46,7 @@ class TestCOWCCounting:
|
|||
'0a4daed8c5f6c4e20faa6e38636e4346',
|
||||
]
|
||||
monkeypatch.setattr(COWCCounting, 'md5s', md5s)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return COWCCounting(root, split, transforms, download=True, checksum=True)
|
||||
|
@ -78,7 +78,7 @@ class TestCOWCCounting:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
COWCCounting(str(tmp_path))
|
||||
COWCCounting(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: COWCCounting) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
@ -110,7 +110,7 @@ class TestCOWCDetection:
|
|||
'dccc2257e9c4a9dde2b4f84769804046',
|
||||
]
|
||||
monkeypatch.setattr(COWCDetection, 'md5s', md5s)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return COWCDetection(root, split, transforms, download=True, checksum=True)
|
||||
|
@ -142,7 +142,7 @@ class TestCOWCDetection:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
COWCDetection(str(tmp_path))
|
||||
COWCDetection(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: COWCDetection) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -17,7 +17,7 @@ from torchgeo.datasets import CropHarvest, DatasetNotFoundError
|
|||
pytest.importorskip('h5py', minversion='3.6')
|
||||
|
||||
|
||||
def download_url(url: str, root: str, filename: str, md5: str) -> None:
|
||||
def download_url(url: str, root: str | Path, filename: str, md5: str) -> None:
|
||||
shutil.copy(url, os.path.join(root, filename))
|
||||
|
||||
|
||||
|
@ -42,7 +42,7 @@ class TestCropHarvest:
|
|||
os.path.join('tests', 'data', 'cropharvest', 'labels.geojson'),
|
||||
)
|
||||
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
|
||||
dataset = CropHarvest(root, transforms, download=True, checksum=True)
|
||||
|
@ -61,16 +61,16 @@ class TestCropHarvest:
|
|||
assert len(dataset) == 5
|
||||
|
||||
def test_already_downloaded(self, dataset: CropHarvest, tmp_path: Path) -> None:
|
||||
CropHarvest(root=str(tmp_path), download=False)
|
||||
CropHarvest(root=tmp_path, download=False)
|
||||
|
||||
def test_downloaded_zipped(self, dataset: CropHarvest, tmp_path: Path) -> None:
|
||||
feature_path = os.path.join(tmp_path, 'features')
|
||||
shutil.rmtree(feature_path)
|
||||
CropHarvest(root=str(tmp_path), download=True)
|
||||
CropHarvest(root=tmp_path, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
CropHarvest(str(tmp_path))
|
||||
CropHarvest(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: CropHarvest) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -30,7 +30,7 @@ class TestCV4AKenyaCropType:
|
|||
monkeypatch.setattr(CV4AKenyaCropType, 'dates', ['20190606'])
|
||||
monkeypatch.setattr(CV4AKenyaCropType, 'tile_height', 2)
|
||||
monkeypatch.setattr(CV4AKenyaCropType, 'tile_width', 2)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return CV4AKenyaCropType(root, transforms=transforms, download=True)
|
||||
|
||||
|
@ -55,7 +55,7 @@ class TestCV4AKenyaCropType:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
CV4AKenyaCropType(str(tmp_path))
|
||||
CV4AKenyaCropType(tmp_path)
|
||||
|
||||
def test_invalid_bands(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
|
|
@ -28,7 +28,7 @@ class TestTropicalCyclone:
|
|||
url = os.path.join('tests', 'data', 'cyclone')
|
||||
monkeypatch.setattr(TropicalCyclone, 'url', url)
|
||||
monkeypatch.setattr(TropicalCyclone, 'size', 2)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return TropicalCyclone(root, split, transforms, download=True)
|
||||
|
@ -60,7 +60,7 @@ class TestTropicalCyclone:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
TropicalCyclone(str(tmp_path))
|
||||
TropicalCyclone(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: TropicalCyclone) -> None:
|
||||
sample = dataset[0]
|
||||
|
|
|
@ -39,16 +39,14 @@ class TestDeepGlobeLandCover:
|
|||
def test_extract(self, tmp_path: Path) -> None:
|
||||
root = os.path.join('tests', 'data', 'deepglobelandcover')
|
||||
filename = 'data.zip'
|
||||
shutil.copyfile(
|
||||
os.path.join(root, filename), os.path.join(str(tmp_path), filename)
|
||||
)
|
||||
DeepGlobeLandCover(root=str(tmp_path))
|
||||
shutil.copyfile(os.path.join(root, filename), os.path.join(tmp_path, filename))
|
||||
DeepGlobeLandCover(root=tmp_path)
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
with open(os.path.join(tmp_path, 'data.zip'), 'w') as f:
|
||||
f.write('bad')
|
||||
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
|
||||
DeepGlobeLandCover(root=str(tmp_path), checksum=True)
|
||||
DeepGlobeLandCover(root=tmp_path, checksum=True)
|
||||
|
||||
def test_invalid_split(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
@ -56,7 +54,7 @@ class TestDeepGlobeLandCover:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
DeepGlobeLandCover(str(tmp_path))
|
||||
DeepGlobeLandCover(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: DeepGlobeLandCover) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -61,13 +61,13 @@ class TestDFC2022:
|
|||
os.path.join('tests', 'data', 'dfc2022', 'val.zip'),
|
||||
os.path.join(tmp_path, 'val.zip'),
|
||||
)
|
||||
DFC2022(root=str(tmp_path))
|
||||
DFC2022(root=tmp_path)
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
with open(os.path.join(tmp_path, 'labeled_train.zip'), 'w') as f:
|
||||
f.write('bad')
|
||||
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
|
||||
DFC2022(root=str(tmp_path), checksum=True)
|
||||
DFC2022(root=tmp_path, checksum=True)
|
||||
|
||||
def test_invalid_split(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
@ -75,7 +75,7 @@ class TestDFC2022:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
DFC2022(str(tmp_path))
|
||||
DFC2022(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: DFC2022) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -38,7 +38,7 @@ class TestEDDMapS:
|
|||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
EDDMapS(str(tmp_path))
|
||||
EDDMapS(tmp_path)
|
||||
|
||||
def test_invalid_query(self, dataset: EDDMapS) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
|
|
@ -24,7 +24,7 @@ from torchgeo.datasets import (
|
|||
from torchgeo.samplers import RandomGeoSampler
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -51,7 +51,7 @@ class TestEnviroAtlas:
|
|||
'_files',
|
||||
['pittsburgh_pa-2010_1m-train_tiles-debuffered', 'spatial_index.geojson'],
|
||||
)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return EnviroAtlas(
|
||||
root,
|
||||
|
@ -85,7 +85,7 @@ class TestEnviroAtlas:
|
|||
EnviroAtlas(root=dataset.root, download=True)
|
||||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
shutil.copy(
|
||||
os.path.join('tests', 'data', 'enviroatlas', 'enviroatlas_lotp.zip'), root
|
||||
)
|
||||
|
@ -93,7 +93,7 @@ class TestEnviroAtlas:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
EnviroAtlas(str(tmp_path), checksum=True)
|
||||
EnviroAtlas(tmp_path, checksum=True)
|
||||
|
||||
def test_out_of_bounds_query(self, dataset: EnviroAtlas) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
|
|
@ -22,7 +22,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -42,7 +42,7 @@ class TestEsri2020:
|
|||
'io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip',
|
||||
)
|
||||
monkeypatch.setattr(Esri2020, 'url', url)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return Esri2020(root, transforms=transforms, download=True, checksum=True)
|
||||
|
||||
|
@ -66,11 +66,11 @@ class TestEsri2020:
|
|||
'io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip',
|
||||
)
|
||||
shutil.copy(url, tmp_path)
|
||||
Esri2020(str(tmp_path))
|
||||
Esri2020(tmp_path)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
Esri2020(str(tmp_path), checksum=True)
|
||||
Esri2020(tmp_path, checksum=True)
|
||||
|
||||
def test_and(self, dataset: Esri2020) -> None:
|
||||
ds = dataset & dataset
|
||||
|
|
|
@ -16,7 +16,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import ETCI2021, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -48,7 +48,7 @@ class TestETCI2021:
|
|||
},
|
||||
}
|
||||
monkeypatch.setattr(ETCI2021, 'metadata', metadata)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return ETCI2021(root, split, transforms, download=True, checksum=True)
|
||||
|
@ -78,7 +78,7 @@ class TestETCI2021:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
ETCI2021(str(tmp_path))
|
||||
ETCI2021(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: ETCI2021) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -28,7 +28,7 @@ class TestEUDEM:
|
|||
monkeypatch.setattr(EUDEM, 'md5s', md5s)
|
||||
zipfile = os.path.join('tests', 'data', 'eudem', 'eu_dem_v11_E30N10.zip')
|
||||
shutil.copy(zipfile, tmp_path)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return EUDEM(root, transforms=transforms)
|
||||
|
||||
|
@ -42,7 +42,7 @@ class TestEUDEM:
|
|||
assert len(dataset) == 1
|
||||
|
||||
def test_extracted_already(self, dataset: EUDEM) -> None:
|
||||
assert isinstance(dataset.paths, str)
|
||||
assert isinstance(dataset.paths, Path)
|
||||
zipfile = os.path.join(dataset.paths, 'eu_dem_v11_E30N10.zip')
|
||||
shutil.unpack_archive(zipfile, dataset.paths, 'zip')
|
||||
EUDEM(dataset.paths)
|
||||
|
@ -51,13 +51,13 @@ class TestEUDEM:
|
|||
shutil.rmtree(tmp_path)
|
||||
os.makedirs(tmp_path)
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
EUDEM(str(tmp_path))
|
||||
EUDEM(tmp_path)
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
with open(os.path.join(tmp_path, 'eu_dem_v11_E30N10.zip'), 'w') as f:
|
||||
f.write('bad')
|
||||
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
|
||||
EUDEM(str(tmp_path), checksum=True)
|
||||
EUDEM(tmp_path, checksum=True)
|
||||
|
||||
def test_and(self, dataset: EUDEM) -> None:
|
||||
ds = dataset & dataset
|
||||
|
|
|
@ -23,7 +23,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -42,7 +42,7 @@ class TestEuroCrops:
|
|||
base_url = os.path.join('tests', 'data', 'eurocrops') + os.sep
|
||||
monkeypatch.setattr(EuroCrops, 'base_url', base_url)
|
||||
monkeypatch.setattr(plt, 'show', lambda *args: None)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return EuroCrops(
|
||||
root, classes=classes, transforms=transforms, download=True, checksum=True
|
||||
|
@ -81,7 +81,7 @@ class TestEuroCrops:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
EuroCrops(str(tmp_path))
|
||||
EuroCrops(tmp_path)
|
||||
|
||||
def test_invalid_query(self, dataset: EuroCrops) -> None:
|
||||
query = BoundingBox(200, 200, 200, 200, 2, 2)
|
||||
|
|
|
@ -24,7 +24,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -61,7 +61,7 @@ class TestEuroSAT:
|
|||
'test': '4af60a00fdfdf8500572ae5360694b71',
|
||||
},
|
||||
)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return base_class(
|
||||
root=root, split=split, transforms=transforms, download=True, checksum=True
|
||||
|
@ -90,18 +90,18 @@ class TestEuroSAT:
|
|||
assert len(ds) == 4
|
||||
|
||||
def test_already_downloaded(self, dataset: EuroSAT, tmp_path: Path) -> None:
|
||||
EuroSAT(root=str(tmp_path), download=True)
|
||||
EuroSAT(root=tmp_path, download=True)
|
||||
|
||||
def test_already_downloaded_not_extracted(
|
||||
self, dataset: EuroSAT, tmp_path: Path
|
||||
) -> None:
|
||||
shutil.rmtree(dataset.root)
|
||||
download_url(dataset.url, root=str(tmp_path))
|
||||
EuroSAT(root=str(tmp_path), download=False)
|
||||
download_url(dataset.url, root=tmp_path)
|
||||
EuroSAT(root=tmp_path, download=False)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
EuroSAT(str(tmp_path))
|
||||
EuroSAT(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: EuroSAT) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
@ -114,7 +114,7 @@ class TestEuroSAT:
|
|||
plt.close()
|
||||
|
||||
def test_plot_rgb(self, dataset: EuroSAT, tmp_path: Path) -> None:
|
||||
dataset = EuroSAT(root=str(tmp_path), bands=('B03',))
|
||||
dataset = EuroSAT(root=tmp_path, bands=('B03',))
|
||||
with pytest.raises(
|
||||
RGBBandsMissingError, match='Dataset does not contain some of the RGB bands'
|
||||
):
|
||||
|
|
|
@ -16,7 +16,9 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import FAIR1M, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(
|
||||
url: str, root: str | Path, filename: str, *args: str, **kwargs: str
|
||||
) -> None:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
shutil.copy(url, os.path.join(root, filename))
|
||||
|
||||
|
@ -65,7 +67,7 @@ class TestFAIR1M:
|
|||
}
|
||||
monkeypatch.setattr(FAIR1M, 'urls', urls)
|
||||
monkeypatch.setattr(FAIR1M, 'md5s', md5s)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return FAIR1M(root, split, transforms, download=True, checksum=True)
|
||||
|
@ -89,7 +91,7 @@ class TestFAIR1M:
|
|||
assert len(dataset) == 4
|
||||
|
||||
def test_already_downloaded(self, dataset: FAIR1M, tmp_path: Path) -> None:
|
||||
FAIR1M(root=str(tmp_path), split=dataset.split, download=True)
|
||||
FAIR1M(root=tmp_path, split=dataset.split, download=True)
|
||||
|
||||
def test_already_downloaded_not_extracted(
|
||||
self, dataset: FAIR1M, tmp_path: Path
|
||||
|
@ -98,11 +100,11 @@ class TestFAIR1M:
|
|||
for filepath, url in zip(
|
||||
dataset.paths[dataset.split], dataset.urls[dataset.split]
|
||||
):
|
||||
output = os.path.join(str(tmp_path), filepath)
|
||||
output = os.path.join(tmp_path, filepath)
|
||||
os.makedirs(os.path.dirname(output), exist_ok=True)
|
||||
download_url(url, root=os.path.dirname(output), filename=output)
|
||||
|
||||
FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True)
|
||||
FAIR1M(root=tmp_path, split=dataset.split, checksum=True)
|
||||
|
||||
def test_corrupted(self, tmp_path: Path, dataset: FAIR1M) -> None:
|
||||
md5s = tuple(['randomhash'] * len(FAIR1M.md5s[dataset.split]))
|
||||
|
@ -111,17 +113,17 @@ class TestFAIR1M:
|
|||
for filepath, url in zip(
|
||||
dataset.paths[dataset.split], dataset.urls[dataset.split]
|
||||
):
|
||||
output = os.path.join(str(tmp_path), filepath)
|
||||
output = os.path.join(tmp_path, filepath)
|
||||
os.makedirs(os.path.dirname(output), exist_ok=True)
|
||||
download_url(url, root=os.path.dirname(output), filename=output)
|
||||
|
||||
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
|
||||
FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True)
|
||||
FAIR1M(root=tmp_path, split=dataset.split, checksum=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path, dataset: FAIR1M) -> None:
|
||||
shutil.rmtree(str(tmp_path))
|
||||
shutil.rmtree(tmp_path)
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
FAIR1M(root=str(tmp_path), split=dataset.split)
|
||||
FAIR1M(root=tmp_path, split=dataset.split)
|
||||
|
||||
def test_plot(self, dataset: FAIR1M) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -16,7 +16,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import DatasetNotFoundError, FireRisk
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -30,7 +30,7 @@ class TestFireRisk:
|
|||
md5 = 'db22106d61b10d855234b4a74db921ac'
|
||||
monkeypatch.setattr(FireRisk, 'md5', md5)
|
||||
monkeypatch.setattr(FireRisk, 'url', url)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return FireRisk(root, split, transforms, download=True, checksum=True)
|
||||
|
@ -46,18 +46,18 @@ class TestFireRisk:
|
|||
assert len(dataset) == 5
|
||||
|
||||
def test_already_downloaded(self, dataset: FireRisk, tmp_path: Path) -> None:
|
||||
FireRisk(root=str(tmp_path), download=True)
|
||||
FireRisk(root=tmp_path, download=True)
|
||||
|
||||
def test_already_downloaded_not_extracted(
|
||||
self, dataset: FireRisk, tmp_path: Path
|
||||
) -> None:
|
||||
shutil.rmtree(os.path.dirname(dataset.root))
|
||||
download_url(dataset.url, root=str(tmp_path))
|
||||
FireRisk(root=str(tmp_path), download=False)
|
||||
download_url(dataset.url, root=tmp_path)
|
||||
FireRisk(root=tmp_path, download=False)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
FireRisk(str(tmp_path))
|
||||
FireRisk(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: FireRisk) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -15,7 +15,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import DatasetNotFoundError, ForestDamage
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -31,7 +31,7 @@ class TestForestDamage:
|
|||
|
||||
monkeypatch.setattr(ForestDamage, 'url', url)
|
||||
monkeypatch.setattr(ForestDamage, 'md5', md5)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return ForestDamage(
|
||||
root=root, transforms=transforms, download=True, checksum=True
|
||||
|
@ -57,17 +57,17 @@ class TestForestDamage:
|
|||
'tests', 'data', 'forestdamage', 'Data_Set_Larch_Casebearer.zip'
|
||||
)
|
||||
shutil.copy(url, tmp_path)
|
||||
ForestDamage(root=str(tmp_path))
|
||||
ForestDamage(root=tmp_path)
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
with open(os.path.join(tmp_path, 'Data_Set_Larch_Casebearer.zip'), 'w') as f:
|
||||
f.write('bad')
|
||||
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
|
||||
ForestDamage(root=str(tmp_path), checksum=True)
|
||||
ForestDamage(root=tmp_path, checksum=True)
|
||||
|
||||
def test_not_found(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
ForestDamage(str(tmp_path))
|
||||
ForestDamage(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: ForestDamage) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -38,7 +38,7 @@ class TestGBIF:
|
|||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
GBIF(str(tmp_path))
|
||||
GBIF(tmp_path)
|
||||
|
||||
def test_invalid_query(self, dataset: GBIF) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
|
|
@ -36,7 +36,7 @@ class CustomGeoDataset(GeoDataset):
|
|||
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
|
||||
crs: CRS = CRS.from_epsg(4087),
|
||||
res: float = 1,
|
||||
paths: str | Iterable[str] | None = None,
|
||||
paths: str | Path | Iterable[str | Path] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.index.insert(0, tuple(bounds))
|
||||
|
@ -172,7 +172,7 @@ class TestGeoDataset:
|
|||
dataset & ds2 # type: ignore[operator]
|
||||
|
||||
def test_files_property_for_non_existing_file_or_dir(self, tmp_path: Path) -> None:
|
||||
paths = [str(tmp_path), str(tmp_path / 'non_existing_file.tif')]
|
||||
paths = [tmp_path, tmp_path / 'non_existing_file.tif']
|
||||
with pytest.warns(UserWarning, match='Path was ignored.'):
|
||||
assert len(CustomGeoDataset(paths=paths).files) == 0
|
||||
|
||||
|
@ -311,7 +311,7 @@ class TestRasterDataset:
|
|||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
RasterDataset(str(tmp_path))
|
||||
RasterDataset(tmp_path)
|
||||
|
||||
def test_no_all_bands(self) -> None:
|
||||
root = os.path.join('tests', 'data', 'sentinel2')
|
||||
|
@ -380,7 +380,7 @@ class TestVectorDataset:
|
|||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
VectorDataset(str(tmp_path))
|
||||
VectorDataset(tmp_path)
|
||||
|
||||
|
||||
class TestNonGeoDataset:
|
||||
|
|
|
@ -16,7 +16,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import GID15, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -30,7 +30,7 @@ class TestGID15:
|
|||
monkeypatch.setattr(GID15, 'md5', md5)
|
||||
url = os.path.join('tests', 'data', 'gid15', 'gid-15.zip')
|
||||
monkeypatch.setattr(GID15, 'url', url)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return GID15(root, split, transforms, download=True, checksum=True)
|
||||
|
@ -59,7 +59,7 @@ class TestGID15:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
GID15(str(tmp_path))
|
||||
GID15(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: GID15) -> None:
|
||||
dataset.plot(dataset[0], suptitle='Test')
|
||||
|
|
|
@ -37,7 +37,7 @@ class TestGlobBiomass:
|
|||
}
|
||||
|
||||
monkeypatch.setattr(GlobBiomass, 'md5s', md5s)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return GlobBiomass(root, transforms=transforms, checksum=True)
|
||||
|
||||
|
@ -55,13 +55,13 @@ class TestGlobBiomass:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
GlobBiomass(str(tmp_path), checksum=True)
|
||||
GlobBiomass(tmp_path, checksum=True)
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
with open(os.path.join(tmp_path, 'N00E020_agb.zip'), 'w') as f:
|
||||
f.write('bad')
|
||||
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
|
||||
GlobBiomass(str(tmp_path), checksum=True)
|
||||
GlobBiomass(tmp_path, checksum=True)
|
||||
|
||||
def test_and(self, dataset: GlobBiomass) -> None:
|
||||
ds = dataset & dataset
|
||||
|
|
|
@ -19,7 +19,7 @@ from torchgeo.datasets import DatasetNotFoundError, IDTReeS
|
|||
pytest.importorskip('laspy', minversion='2')
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -44,7 +44,7 @@ class TestIDTReeS:
|
|||
}
|
||||
split, task = request.param
|
||||
monkeypatch.setattr(IDTReeS, 'metadata', metadata)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return IDTReeS(root, split, task, transforms, download=True, checksum=True)
|
||||
|
||||
|
@ -77,11 +77,11 @@ class TestIDTReeS:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
IDTReeS(str(tmp_path))
|
||||
IDTReeS(tmp_path)
|
||||
|
||||
def test_not_extracted(self, tmp_path: Path) -> None:
|
||||
pathname = os.path.join('tests', 'data', 'idtrees', '*.zip')
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
for zipfile in glob.iglob(pathname):
|
||||
shutil.copy(zipfile, root)
|
||||
IDTReeS(root)
|
||||
|
|
|
@ -38,7 +38,7 @@ class TestINaturalist:
|
|||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
INaturalist(str(tmp_path))
|
||||
INaturalist(tmp_path)
|
||||
|
||||
def test_invalid_query(self, dataset: INaturalist) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
|
|
@ -50,7 +50,7 @@ class TestInriaAerialImageLabeling:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: str) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
InriaAerialImageLabeling(str(tmp_path))
|
||||
InriaAerialImageLabeling(tmp_path)
|
||||
|
||||
def test_dataset_checksum(self, dataset: InriaAerialImageLabeling) -> None:
|
||||
InriaAerialImageLabeling.md5 = 'randommd5hash123'
|
||||
|
|
|
@ -24,7 +24,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -36,7 +36,7 @@ class TestIOBench:
|
|||
url = os.path.join('tests', 'data', 'iobench', '{}.tar.gz')
|
||||
monkeypatch.setattr(IOBench, 'url', url)
|
||||
monkeypatch.setitem(IOBench.md5s, 'preprocessed', md5)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return IOBench(root, transforms=transforms, download=True, checksum=True)
|
||||
|
||||
|
@ -68,14 +68,14 @@ class TestIOBench:
|
|||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
pathname = os.path.join('tests', 'data', 'iobench', '*.tar.gz')
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
for tarfile in glob.iglob(pathname):
|
||||
shutil.copy(tarfile, root)
|
||||
IOBench(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
IOBench(str(tmp_path))
|
||||
IOBench(tmp_path)
|
||||
|
||||
def test_invalid_query(self, dataset: IOBench) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
|
|
@ -25,7 +25,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -41,7 +41,7 @@ class TestL7Irish:
|
|||
url = os.path.join('tests', 'data', 'l7irish', '{}.tar.gz')
|
||||
monkeypatch.setattr(L7Irish, 'url', url)
|
||||
monkeypatch.setattr(L7Irish, 'md5s', md5s)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return L7Irish(root, transforms=transforms, download=True, checksum=True)
|
||||
|
||||
|
@ -75,14 +75,14 @@ class TestL7Irish:
|
|||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
pathname = os.path.join('tests', 'data', 'l7irish', '*.tar.gz')
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
for tarfile in glob.iglob(pathname):
|
||||
shutil.copy(tarfile, root)
|
||||
L7Irish(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
L7Irish(str(tmp_path))
|
||||
L7Irish(tmp_path)
|
||||
|
||||
def test_plot_prediction(self, dataset: L7Irish) -> None:
|
||||
x = dataset[dataset.bounds]
|
||||
|
|
|
@ -25,7 +25,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -41,7 +41,7 @@ class TestL8Biome:
|
|||
url = os.path.join('tests', 'data', 'l8biome', '{}.tar.gz')
|
||||
monkeypatch.setattr(L8Biome, 'url', url)
|
||||
monkeypatch.setattr(L8Biome, 'md5s', md5s)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return L8Biome(root, transforms=transforms, download=True, checksum=True)
|
||||
|
||||
|
@ -75,14 +75,14 @@ class TestL8Biome:
|
|||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
pathname = os.path.join('tests', 'data', 'l8biome', '*.tar.gz')
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
for tarfile in glob.iglob(pathname):
|
||||
shutil.copy(tarfile, root)
|
||||
L8Biome(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
L8Biome(str(tmp_path))
|
||||
L8Biome(tmp_path)
|
||||
|
||||
def test_plot_prediction(self, dataset: L8Biome) -> None:
|
||||
x = dataset[dataset.bounds]
|
||||
|
|
|
@ -22,7 +22,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -34,7 +34,7 @@ class TestLandCoverAIGeo:
|
|||
monkeypatch.setattr(LandCoverAIGeo, 'md5', md5)
|
||||
url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')
|
||||
monkeypatch.setattr(LandCoverAIGeo, 'url', url)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return LandCoverAIGeo(root, transforms=transforms, download=True, checksum=True)
|
||||
|
||||
|
@ -49,13 +49,13 @@ class TestLandCoverAIGeo:
|
|||
|
||||
def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
|
||||
url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
shutil.copy(url, root)
|
||||
LandCoverAIGeo(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
LandCoverAIGeo(str(tmp_path))
|
||||
LandCoverAIGeo(tmp_path)
|
||||
|
||||
def test_out_of_bounds_query(self, dataset: LandCoverAIGeo) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
@ -89,7 +89,7 @@ class TestLandCoverAI:
|
|||
monkeypatch.setattr(LandCoverAI, 'url', url)
|
||||
sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
|
||||
monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return LandCoverAI(root, split, transforms, download=True, checksum=True)
|
||||
|
@ -115,13 +115,13 @@ class TestLandCoverAI:
|
|||
sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
|
||||
monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
|
||||
url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
shutil.copy(url, root)
|
||||
LandCoverAI(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
LandCoverAI(str(tmp_path))
|
||||
LandCoverAI(tmp_path)
|
||||
|
||||
def test_invalid_split(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
|
|
@ -71,7 +71,7 @@ class TestLandsat8:
|
|||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
Landsat8(str(tmp_path))
|
||||
Landsat8(tmp_path)
|
||||
|
||||
def test_invalid_query(self, dataset: Landsat8) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
|
|
@ -16,7 +16,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import LEVIRCD, DatasetNotFoundError, LEVIRCDPlus
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -45,7 +45,7 @@ class TestLEVIRCD:
|
|||
}
|
||||
monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url)
|
||||
monkeypatch.setattr(LEVIRCD, 'splits', splits)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return LEVIRCD(root, split, transforms, download=True, checksum=True)
|
||||
|
@ -71,7 +71,7 @@ class TestLEVIRCD:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
LEVIRCD(str(tmp_path))
|
||||
LEVIRCD(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: LEVIRCD) -> None:
|
||||
dataset.plot(dataset[0], suptitle='Test')
|
||||
|
@ -93,7 +93,7 @@ class TestLEVIRCDPlus:
|
|||
monkeypatch.setattr(LEVIRCDPlus, 'md5', md5)
|
||||
url = os.path.join('tests', 'data', 'levircd', 'levircdplus', 'LEVIR-CD+.zip')
|
||||
monkeypatch.setattr(LEVIRCDPlus, 'url', url)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return LEVIRCDPlus(root, split, transforms, download=True, checksum=True)
|
||||
|
@ -119,7 +119,7 @@ class TestLEVIRCDPlus:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
LEVIRCDPlus(str(tmp_path))
|
||||
LEVIRCDPlus(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: LEVIRCDPlus) -> None:
|
||||
dataset.plot(dataset[0], suptitle='Test')
|
||||
|
|
|
@ -16,7 +16,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import DatasetNotFoundError, LoveDA
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -48,7 +48,7 @@ class TestLoveDA:
|
|||
|
||||
monkeypatch.setattr(LoveDA, 'info_dict', info_dict)
|
||||
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return LoveDA(
|
||||
|
@ -84,7 +84,7 @@ class TestLoveDA:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
LoveDA(str(tmp_path))
|
||||
LoveDA(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: LoveDA) -> None:
|
||||
dataset.plot(dataset[0], suptitle='Test')
|
||||
|
|
|
@ -18,7 +18,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import DatasetNotFoundError, MapInWild
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -53,7 +53,7 @@ class TestMapInWild:
|
|||
urls = os.path.join('tests', 'data', 'mapinwild')
|
||||
monkeypatch.setattr(MapInWild, 'url', urls)
|
||||
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
|
||||
transforms = nn.Identity()
|
||||
|
@ -98,12 +98,12 @@ class TestMapInWild:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
MapInWild(root=str(tmp_path))
|
||||
MapInWild(root=tmp_path)
|
||||
|
||||
def test_downloaded_not_extracted(self, tmp_path: Path) -> None:
|
||||
pathname = os.path.join('tests', 'data', 'mapinwild', '*', '*')
|
||||
pathname_glob = glob.glob(pathname)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
for zipfile in pathname_glob:
|
||||
shutil.copy(zipfile, root)
|
||||
MapInWild(root, download=False)
|
||||
|
@ -111,7 +111,7 @@ class TestMapInWild:
|
|||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
pathname = os.path.join('tests', 'data', 'mapinwild', '**', '*.zip')
|
||||
pathname_glob = glob.glob(pathname, recursive=True)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
for zipfile in pathname_glob:
|
||||
shutil.copy(zipfile, root)
|
||||
splitfile = os.path.join(
|
||||
|
@ -121,10 +121,10 @@ class TestMapInWild:
|
|||
with open(os.path.join(tmp_path, 'mask.zip'), 'w') as f:
|
||||
f.write('bad')
|
||||
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
|
||||
MapInWild(root=str(tmp_path), download=True, checksum=True)
|
||||
MapInWild(root=tmp_path, download=True, checksum=True)
|
||||
|
||||
def test_already_downloaded(self, dataset: MapInWild, tmp_path: Path) -> None:
|
||||
MapInWild(root=str(tmp_path), modality=dataset.modality, download=True)
|
||||
MapInWild(root=tmp_path, modality=dataset.modality, download=True)
|
||||
|
||||
def test_plot(self, dataset: MapInWild) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -39,18 +39,18 @@ class TestMillionAID:
|
|||
|
||||
def test_not_found(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
MillionAID(str(tmp_path))
|
||||
MillionAID(tmp_path)
|
||||
|
||||
def test_not_extracted(self, tmp_path: Path) -> None:
|
||||
url = os.path.join('tests', 'data', 'millionaid', 'train.zip')
|
||||
shutil.copy(url, tmp_path)
|
||||
MillionAID(str(tmp_path))
|
||||
MillionAID(tmp_path)
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
with open(os.path.join(tmp_path, 'train.zip'), 'w') as f:
|
||||
f.write('bad')
|
||||
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
|
||||
MillionAID(str(tmp_path), checksum=True)
|
||||
MillionAID(tmp_path, checksum=True)
|
||||
|
||||
def test_plot(self, dataset: MillionAID) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -51,7 +51,7 @@ class TestNAIP:
|
|||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
NAIP(str(tmp_path))
|
||||
NAIP(tmp_path)
|
||||
|
||||
def test_invalid_query(self, dataset: NAIP) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
|
|
@ -45,7 +45,7 @@ class TestNASAMarineDebris:
|
|||
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
|
||||
md5s = ['6f4f0d2313323950e45bf3fc0c09b5de', '540cf1cf4fd2c13b609d0355abe955d7']
|
||||
monkeypatch.setattr(NASAMarineDebris, 'md5s', md5s)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return NASAMarineDebris(root, transforms, download=True, checksum=True)
|
||||
|
||||
|
@ -63,15 +63,15 @@ class TestNASAMarineDebris:
|
|||
def test_already_downloaded(
|
||||
self, dataset: NASAMarineDebris, tmp_path: Path
|
||||
) -> None:
|
||||
NASAMarineDebris(root=str(tmp_path), download=True)
|
||||
NASAMarineDebris(root=tmp_path, download=True)
|
||||
|
||||
def test_already_downloaded_not_extracted(
|
||||
self, dataset: NASAMarineDebris, tmp_path: Path
|
||||
) -> None:
|
||||
shutil.rmtree(dataset.root)
|
||||
os.makedirs(str(tmp_path), exist_ok=True)
|
||||
os.makedirs(tmp_path, exist_ok=True)
|
||||
Collection().download(output_dir=str(tmp_path))
|
||||
NASAMarineDebris(root=str(tmp_path), download=False)
|
||||
NASAMarineDebris(root=tmp_path, download=False)
|
||||
|
||||
def test_corrupted_previously_downloaded(self, tmp_path: Path) -> None:
|
||||
filenames = NASAMarineDebris.filenames
|
||||
|
@ -79,7 +79,7 @@ class TestNASAMarineDebris:
|
|||
with open(os.path.join(tmp_path, filename), 'w') as f:
|
||||
f.write('bad')
|
||||
with pytest.raises(RuntimeError, match='Dataset checksum mismatch.'):
|
||||
NASAMarineDebris(root=str(tmp_path), download=False, checksum=True)
|
||||
NASAMarineDebris(root=tmp_path, download=False, checksum=True)
|
||||
|
||||
def test_corrupted_new_download(
|
||||
self, tmp_path: Path, monkeypatch: MonkeyPatch
|
||||
|
@ -87,11 +87,11 @@ class TestNASAMarineDebris:
|
|||
with pytest.raises(RuntimeError, match='Dataset checksum mismatch.'):
|
||||
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
|
||||
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_corrupted)
|
||||
NASAMarineDebris(root=str(tmp_path), download=True, checksum=True)
|
||||
NASAMarineDebris(root=tmp_path, download=True, checksum=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
NASAMarineDebris(str(tmp_path))
|
||||
NASAMarineDebris(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: NASAMarineDebris) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -22,7 +22,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -43,7 +43,7 @@ class TestNCCM:
|
|||
}
|
||||
monkeypatch.setattr(NCCM, 'urls', urls)
|
||||
transforms = nn.Identity()
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
return NCCM(root, transforms=transforms, download=True, checksum=True)
|
||||
|
||||
def test_getitem(self, dataset: NCCM) -> None:
|
||||
|
@ -84,7 +84,7 @@ class TestNCCM:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
NCCM(str(tmp_path))
|
||||
NCCM(tmp_path)
|
||||
|
||||
def test_invalid_query(self, dataset: NCCM) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
|
|
@ -22,7 +22,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -42,7 +42,7 @@ class TestNLCD:
|
|||
)
|
||||
monkeypatch.setattr(NLCD, 'url', url)
|
||||
monkeypatch.setattr(plt, 'show', lambda *args: None)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return NLCD(
|
||||
root,
|
||||
|
@ -84,7 +84,7 @@ class TestNLCD:
|
|||
pathname = os.path.join(
|
||||
'tests', 'data', 'nlcd', 'nlcd_2019_land_cover_l48_20210604.zip'
|
||||
)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
shutil.copy(pathname, root)
|
||||
NLCD(root, years=[2019])
|
||||
|
||||
|
@ -93,7 +93,7 @@ class TestNLCD:
|
|||
AssertionError,
|
||||
match='NLCD data product only exists for the following years:',
|
||||
):
|
||||
NLCD(str(tmp_path), years=[1996])
|
||||
NLCD(tmp_path, years=[1996])
|
||||
|
||||
def test_invalid_classes(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
@ -117,7 +117,7 @@ class TestNLCD:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
NLCD(str(tmp_path))
|
||||
NLCD(tmp_path)
|
||||
|
||||
def test_invalid_query(self, dataset: NLCD) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
|
|
@ -26,7 +26,7 @@ from torchgeo.datasets import (
|
|||
class TestOpenBuildings:
|
||||
@pytest.fixture
|
||||
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> OpenBuildings:
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
shutil.copy(
|
||||
os.path.join('tests', 'data', 'openbuildings', 'tiles.geojson'), root
|
||||
)
|
||||
|
@ -55,7 +55,7 @@ class TestOpenBuildings:
|
|||
|
||||
def test_not_download(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
OpenBuildings(str(tmp_path))
|
||||
OpenBuildings(tmp_path)
|
||||
|
||||
def test_corrupted(self, dataset: OpenBuildings, tmp_path: Path) -> None:
|
||||
with open(os.path.join(tmp_path, '000_buildings.csv.gz'), 'w') as f:
|
||||
|
|
|
@ -18,7 +18,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import OSCD, DatasetNotFoundError, RGBBandsMissingError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -63,7 +63,7 @@ class TestOSCD:
|
|||
monkeypatch.setattr(OSCD, 'urls', urls)
|
||||
|
||||
bands, split = request.param
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return OSCD(
|
||||
root, split, bands, transforms=transforms, download=True, checksum=True
|
||||
|
@ -101,14 +101,14 @@ class TestOSCD:
|
|||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
pathname = os.path.join('tests', 'data', 'oscd', '*Onera*.zip')
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
for zipfile in glob.iglob(pathname):
|
||||
shutil.copy(zipfile, root)
|
||||
OSCD(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
OSCD(str(tmp_path))
|
||||
OSCD(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: OSCD) -> None:
|
||||
dataset.plot(dataset[0], suptitle='Test')
|
||||
|
|
|
@ -17,7 +17,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import PASTIS, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -38,7 +38,7 @@ class TestPASTIS:
|
|||
monkeypatch.setattr(PASTIS, 'md5', md5)
|
||||
url = os.path.join('tests', 'data', 'pastis', 'PASTIS-R.zip')
|
||||
monkeypatch.setattr(PASTIS, 'url', url)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
folds = request.param['folds']
|
||||
bands = request.param['bands']
|
||||
mode = request.param['mode']
|
||||
|
@ -75,19 +75,19 @@ class TestPASTIS:
|
|||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
url = os.path.join('tests', 'data', 'pastis', 'PASTIS-R.zip')
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
shutil.copy(url, root)
|
||||
PASTIS(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
PASTIS(str(tmp_path))
|
||||
PASTIS(tmp_path)
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
with open(os.path.join(tmp_path, 'PASTIS-R.zip'), 'w') as f:
|
||||
f.write('bad')
|
||||
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
|
||||
PASTIS(root=str(tmp_path), checksum=True)
|
||||
PASTIS(root=tmp_path, checksum=True)
|
||||
|
||||
def test_invalid_fold(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
|
|
@ -15,7 +15,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import DatasetNotFoundError, PatternNet
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -27,7 +27,7 @@ class TestPatternNet:
|
|||
monkeypatch.setattr(PatternNet, 'md5', md5)
|
||||
url = os.path.join('tests', 'data', 'patternnet', 'PatternNet.zip')
|
||||
monkeypatch.setattr(PatternNet, 'url', url)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return PatternNet(root, transforms, download=True, checksum=True)
|
||||
|
||||
|
@ -42,18 +42,18 @@ class TestPatternNet:
|
|||
assert len(dataset) == 2
|
||||
|
||||
def test_already_downloaded(self, dataset: PatternNet, tmp_path: Path) -> None:
|
||||
PatternNet(root=str(tmp_path), download=True)
|
||||
PatternNet(root=tmp_path, download=True)
|
||||
|
||||
def test_already_downloaded_not_extracted(
|
||||
self, dataset: PatternNet, tmp_path: Path
|
||||
) -> None:
|
||||
shutil.rmtree(dataset.root)
|
||||
download_url(dataset.url, root=str(tmp_path))
|
||||
PatternNet(root=str(tmp_path), download=False)
|
||||
download_url(dataset.url, root=tmp_path)
|
||||
PatternNet(root=tmp_path, download=False)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
PatternNet(str(tmp_path))
|
||||
PatternNet(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: PatternNet) -> None:
|
||||
dataset.plot(dataset[0], suptitle='Test')
|
||||
|
|
|
@ -43,9 +43,9 @@ class TestPotsdam2D:
|
|||
root = os.path.join('tests', 'data', 'potsdam')
|
||||
for filename in ['4_Ortho_RGBIR.zip', '5_Labels_all.zip']:
|
||||
shutil.copyfile(
|
||||
os.path.join(root, filename), os.path.join(str(tmp_path), filename)
|
||||
os.path.join(root, filename), os.path.join(tmp_path, filename)
|
||||
)
|
||||
Potsdam2D(root=str(tmp_path))
|
||||
Potsdam2D(root=tmp_path)
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
with open(os.path.join(tmp_path, '4_Ortho_RGBIR.zip'), 'w') as f:
|
||||
|
@ -53,7 +53,7 @@ class TestPotsdam2D:
|
|||
with open(os.path.join(tmp_path, '5_Labels_all.zip'), 'w') as f:
|
||||
f.write('bad')
|
||||
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
|
||||
Potsdam2D(root=str(tmp_path), checksum=True)
|
||||
Potsdam2D(root=tmp_path, checksum=True)
|
||||
|
||||
def test_invalid_split(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
@ -61,7 +61,7 @@ class TestPotsdam2D:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
Potsdam2D(str(tmp_path))
|
||||
Potsdam2D(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: Potsdam2D) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -50,7 +50,7 @@ class TestPRISMA:
|
|||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
PRISMA(str(tmp_path))
|
||||
PRISMA(tmp_path)
|
||||
|
||||
def test_invalid_query(self, dataset: PRISMA) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
|
|
@ -18,7 +18,7 @@ from torchgeo.datasets import DatasetNotFoundError, QuakeSet
|
|||
pytest.importorskip('h5py', minversion='3.6')
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -32,7 +32,7 @@ class TestQuakeSet:
|
|||
md5 = '127d0d6a1f82d517129535f50053a4c9'
|
||||
monkeypatch.setattr(QuakeSet, 'md5', md5)
|
||||
monkeypatch.setattr(QuakeSet, 'url', url)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return QuakeSet(
|
||||
|
@ -50,11 +50,11 @@ class TestQuakeSet:
|
|||
assert len(dataset) == 8
|
||||
|
||||
def test_already_downloaded(self, dataset: QuakeSet, tmp_path: Path) -> None:
|
||||
QuakeSet(root=str(tmp_path), download=True)
|
||||
QuakeSet(root=tmp_path, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
QuakeSet(str(tmp_path))
|
||||
QuakeSet(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: QuakeSet) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -15,7 +15,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import DatasetNotFoundError, ReforesTree
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -31,7 +31,7 @@ class TestReforesTree:
|
|||
|
||||
monkeypatch.setattr(ReforesTree, 'url', url)
|
||||
monkeypatch.setattr(ReforesTree, 'md5', md5)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return ReforesTree(
|
||||
root=root, transforms=transforms, download=True, checksum=True
|
||||
|
@ -57,17 +57,17 @@ class TestReforesTree:
|
|||
def test_not_extracted(self, tmp_path: Path) -> None:
|
||||
url = os.path.join('tests', 'data', 'reforestree', 'reforesTree.zip')
|
||||
shutil.copy(url, tmp_path)
|
||||
ReforesTree(root=str(tmp_path))
|
||||
ReforesTree(root=tmp_path)
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
with open(os.path.join(tmp_path, 'reforesTree.zip'), 'w') as f:
|
||||
f.write('bad')
|
||||
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
|
||||
ReforesTree(root=str(tmp_path), checksum=True)
|
||||
ReforesTree(root=tmp_path, checksum=True)
|
||||
|
||||
def test_not_found(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
ReforesTree(str(tmp_path))
|
||||
ReforesTree(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: ReforesTree) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -18,7 +18,7 @@ from torchgeo.datasets import RESISC45, DatasetNotFoundError
|
|||
pytest.importorskip('rarfile', minversion='4')
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -52,7 +52,7 @@ class TestRESISC45:
|
|||
'test': '7760b1960c9a3ff46fb985810815e14d',
|
||||
},
|
||||
)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return RESISC45(root, split, transforms, download=True, checksum=True)
|
||||
|
@ -68,18 +68,18 @@ class TestRESISC45:
|
|||
assert len(dataset) == 9
|
||||
|
||||
def test_already_downloaded(self, dataset: RESISC45, tmp_path: Path) -> None:
|
||||
RESISC45(root=str(tmp_path), download=True)
|
||||
RESISC45(root=tmp_path, download=True)
|
||||
|
||||
def test_already_downloaded_not_extracted(
|
||||
self, dataset: RESISC45, tmp_path: Path
|
||||
) -> None:
|
||||
shutil.rmtree(dataset.root)
|
||||
download_url(dataset.url, root=str(tmp_path))
|
||||
RESISC45(root=str(tmp_path), download=False)
|
||||
download_url(dataset.url, root=tmp_path)
|
||||
RESISC45(root=tmp_path, download=False)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
RESISC45(str(tmp_path))
|
||||
RESISC45(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: RESISC45) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -33,7 +33,7 @@ class TestRwandaFieldBoundary:
|
|||
monkeypatch.setattr(RwandaFieldBoundary, 'url', url)
|
||||
monkeypatch.setattr(RwandaFieldBoundary, 'splits', {'train': 1, 'test': 1})
|
||||
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return RwandaFieldBoundary(root, split, transforms=transforms, download=True)
|
||||
|
@ -60,7 +60,7 @@ class TestRwandaFieldBoundary:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
RwandaFieldBoundary(str(tmp_path))
|
||||
RwandaFieldBoundary(tmp_path)
|
||||
|
||||
def test_invalid_bands(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
|
|
@ -18,7 +18,9 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, SeasoNet
|
||||
|
||||
|
||||
def download_url(url: str, root: str, md5: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(
|
||||
url: str, root: str | Path, md5: str, *args: str, **kwargs: str
|
||||
) -> None:
|
||||
shutil.copy(url, root)
|
||||
torchgeo.datasets.utils.check_integrity(
|
||||
os.path.join(root, os.path.basename(url)), md5
|
||||
|
@ -95,7 +97,7 @@ class TestSeasoNet:
|
|||
'url',
|
||||
os.path.join('tests', 'data', 'seasonet', 'meta.csv'),
|
||||
)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split, seasons, bands, grids, concat_seasons = request.param
|
||||
transforms = nn.Identity()
|
||||
return SeasoNet(
|
||||
|
@ -141,14 +143,14 @@ class TestSeasoNet:
|
|||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
paths = os.path.join('tests', 'data', 'seasonet', '*.*')
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
for path in glob.iglob(paths):
|
||||
shutil.copy(path, root)
|
||||
SeasoNet(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SeasoNet(str(tmp_path), download=False)
|
||||
SeasoNet(tmp_path, download=False)
|
||||
|
||||
def test_out_of_bounds(self, dataset: SeasoNet) -> None:
|
||||
with pytest.raises(IndexError):
|
||||
|
|
|
@ -22,7 +22,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -56,7 +56,7 @@ class TestSeasonalContrastS2:
|
|||
monkeypatch.setitem(
|
||||
SeasonalContrastS2.metadata['1m'], 'md5', '3bb3fcf90f5de7d5781ce0cb85fd20af'
|
||||
)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
version, seasons, bands = request.param
|
||||
transforms = nn.Identity()
|
||||
return SeasonalContrastS2(
|
||||
|
@ -88,7 +88,7 @@ class TestSeasonalContrastS2:
|
|||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
pathname = os.path.join('tests', 'data', 'seco', '*.zip')
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
for zipfile in glob.iglob(pathname):
|
||||
shutil.copy(zipfile, root)
|
||||
SeasonalContrastS2(root)
|
||||
|
@ -103,7 +103,7 @@ class TestSeasonalContrastS2:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SeasonalContrastS2(str(tmp_path))
|
||||
SeasonalContrastS2(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: SeasonalContrastS2) -> None:
|
||||
x = dataset[0]
|
||||
|
|
|
@ -66,10 +66,10 @@ class TestSEN12MS:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SEN12MS(str(tmp_path), checksum=True)
|
||||
SEN12MS(tmp_path, checksum=True)
|
||||
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SEN12MS(str(tmp_path), checksum=False)
|
||||
SEN12MS(tmp_path, checksum=False)
|
||||
|
||||
def test_check_integrity_light(self) -> None:
|
||||
root = os.path.join('tests', 'data', 'sen12ms')
|
||||
|
|
|
@ -70,7 +70,7 @@ class TestSentinel1:
|
|||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
Sentinel1(str(tmp_path))
|
||||
Sentinel1(tmp_path)
|
||||
|
||||
def test_empty_bands(self) -> None:
|
||||
with pytest.raises(AssertionError, match="'bands' cannot be an empty list"):
|
||||
|
@ -132,7 +132,7 @@ class TestSentinel2:
|
|||
|
||||
def test_no_data(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
Sentinel2(str(tmp_path))
|
||||
Sentinel2(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: Sentinel2) -> None:
|
||||
x = dataset[dataset.bounds]
|
||||
|
|
|
@ -19,7 +19,7 @@ from torchgeo.datasets import SKIPPD, DatasetNotFoundError
|
|||
pytest.importorskip('h5py', minversion='3.6')
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -40,7 +40,7 @@ class TestSKIPPD:
|
|||
url = os.path.join('tests', 'data', 'skippd', '{}')
|
||||
monkeypatch.setattr(SKIPPD, 'url', url)
|
||||
monkeypatch.setattr(plt, 'show', lambda *args: None)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return SKIPPD(
|
||||
root=root,
|
||||
|
@ -59,7 +59,7 @@ class TestSKIPPD:
|
|||
pathname = os.path.join(
|
||||
'tests', 'data', 'skippd', f'2017_2019_images_pv_processed_{task}.zip'
|
||||
)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
shutil.copy(pathname, root)
|
||||
SKIPPD(root=root, task=task)
|
||||
|
||||
|
@ -84,7 +84,7 @@ class TestSKIPPD:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SKIPPD(str(tmp_path))
|
||||
SKIPPD(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: SKIPPD) -> None:
|
||||
dataset.plot(dataset[0], suptitle='Test')
|
||||
|
|
|
@ -58,7 +58,7 @@ class TestSo2Sat:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
So2Sat(str(tmp_path))
|
||||
So2Sat(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: So2Sat) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -52,7 +52,7 @@ class TestSouthAfricaCropType:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SouthAfricaCropType(str(tmp_path))
|
||||
SouthAfricaCropType(tmp_path)
|
||||
|
||||
def test_plot(self) -> None:
|
||||
path = os.path.join('tests', 'data', 'south_africa_crop_type')
|
||||
|
|
|
@ -21,7 +21,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -37,7 +37,7 @@ class TestSouthAmericaSoybean:
|
|||
)
|
||||
|
||||
monkeypatch.setattr(SouthAmericaSoybean, 'url', url)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
return SouthAmericaSoybean(
|
||||
paths=root,
|
||||
years=[2002, 2021],
|
||||
|
@ -70,7 +70,7 @@ class TestSouthAmericaSoybean:
|
|||
pathname = os.path.join(
|
||||
'tests', 'data', 'south_america_soybean', 'SouthAmerica_Soybean_2002.tif'
|
||||
)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
shutil.copy(pathname, root)
|
||||
SouthAmericaSoybean(root)
|
||||
|
||||
|
@ -89,7 +89,7 @@ class TestSouthAmericaSoybean:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SouthAmericaSoybean(str(tmp_path))
|
||||
SouthAmericaSoybean(tmp_path)
|
||||
|
||||
def test_invalid_query(self, dataset: SouthAmericaSoybean) -> None:
|
||||
query = BoundingBox(0, 0, 0, 0, 0, 0)
|
||||
|
|
|
@ -68,7 +68,7 @@ class TestSpaceNet1:
|
|||
|
||||
# Refer https://github.com/python/mypy/issues/1032
|
||||
monkeypatch.setattr(SpaceNet1, 'collection_md5_dict', test_md5)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return SpaceNet1(
|
||||
root, image=request.param, transforms=transforms, download=True, api_key=''
|
||||
|
@ -93,7 +93,7 @@ class TestSpaceNet1:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SpaceNet1(str(tmp_path))
|
||||
SpaceNet1(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: SpaceNet1) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
@ -118,7 +118,7 @@ class TestSpaceNet2:
|
|||
}
|
||||
|
||||
monkeypatch.setattr(SpaceNet2, 'collection_md5_dict', test_md5)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return SpaceNet2(
|
||||
root,
|
||||
|
@ -149,7 +149,7 @@ class TestSpaceNet2:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SpaceNet2(str(tmp_path))
|
||||
SpaceNet2(tmp_path)
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet2) -> None:
|
||||
dataset.collection_md5_dict['sn2_AOI_2_Vegas'] = 'randommd5hash123'
|
||||
|
@ -177,7 +177,7 @@ class TestSpaceNet3:
|
|||
}
|
||||
|
||||
monkeypatch.setattr(SpaceNet3, 'collection_md5_dict', test_md5)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return SpaceNet3(
|
||||
root,
|
||||
|
@ -209,7 +209,7 @@ class TestSpaceNet3:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SpaceNet3(str(tmp_path))
|
||||
SpaceNet3(tmp_path)
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet3) -> None:
|
||||
dataset.collection_md5_dict['sn3_AOI_5_Khartoum'] = 'randommd5hash123'
|
||||
|
@ -240,7 +240,7 @@ class TestSpaceNet4:
|
|||
test_angles = ['nadir', 'off-nadir', 'very-off-nadir']
|
||||
|
||||
monkeypatch.setattr(SpaceNet4, 'collection_md5_dict', test_md5)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return SpaceNet4(
|
||||
root,
|
||||
|
@ -273,7 +273,7 @@ class TestSpaceNet4:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SpaceNet4(str(tmp_path))
|
||||
SpaceNet4(tmp_path)
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet4) -> None:
|
||||
dataset.collection_md5_dict['sn4_AOI_6_Atlanta'] = 'randommd5hash123'
|
||||
|
@ -303,7 +303,7 @@ class TestSpaceNet5:
|
|||
}
|
||||
|
||||
monkeypatch.setattr(SpaceNet5, 'collection_md5_dict', test_md5)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return SpaceNet5(
|
||||
root,
|
||||
|
@ -335,7 +335,7 @@ class TestSpaceNet5:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SpaceNet5(str(tmp_path))
|
||||
SpaceNet5(tmp_path)
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet5) -> None:
|
||||
dataset.collection_md5_dict['sn5_AOI_8_Mumbai'] = 'randommd5hash123'
|
||||
|
@ -359,7 +359,7 @@ class TestSpaceNet6:
|
|||
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||
) -> SpaceNet6:
|
||||
monkeypatch.setattr(radiant_mlhub.Dataset, 'fetch', fetch_dataset)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return SpaceNet6(
|
||||
root, image=request.param, transforms=transforms, download=True, api_key=''
|
||||
|
@ -405,7 +405,7 @@ class TestSpaceNet7:
|
|||
}
|
||||
|
||||
monkeypatch.setattr(SpaceNet7, 'collection_md5_dict', test_md5)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return SpaceNet7(
|
||||
root, split=request.param, transforms=transforms, download=True, api_key=''
|
||||
|
@ -429,7 +429,7 @@ class TestSpaceNet7:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SpaceNet7(str(tmp_path))
|
||||
SpaceNet7(tmp_path)
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet4) -> None:
|
||||
dataset.collection_md5_dict['sn7_train_source'] = 'randommd5hash123'
|
||||
|
|
|
@ -18,7 +18,7 @@ import torchgeo
|
|||
from torchgeo.datasets import SSL4EOL, SSL4EOS12, DatasetNotFoundError
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -61,7 +61,7 @@ class TestSSL4EOL:
|
|||
}
|
||||
monkeypatch.setattr(SSL4EOL, 'checksums', checksums)
|
||||
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split, seasons = request.param
|
||||
transforms = nn.Identity()
|
||||
return SSL4EOL(root, split, seasons, transforms, download=True, checksum=True)
|
||||
|
@ -88,14 +88,14 @@ class TestSSL4EOL:
|
|||
|
||||
def test_already_downloaded(self, dataset: SSL4EOL, tmp_path: Path) -> None:
|
||||
pathname = os.path.join('tests', 'data', 'ssl4eo', 'l', '*.tar.gz*')
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
for tarfile in glob.iglob(pathname):
|
||||
shutil.copy(tarfile, root)
|
||||
SSL4EOL(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SSL4EOL(str(tmp_path))
|
||||
SSL4EOL(tmp_path)
|
||||
|
||||
def test_invalid_split(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
@ -148,7 +148,7 @@ class TestSSL4EOS12:
|
|||
os.path.join('tests', 'data', 'ssl4eo', 's12', filename),
|
||||
tmp_path / filename,
|
||||
)
|
||||
SSL4EOS12(str(tmp_path))
|
||||
SSL4EOS12(tmp_path)
|
||||
|
||||
def test_invalid_split(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
@ -156,7 +156,7 @@ class TestSSL4EOS12:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SSL4EOS12(str(tmp_path))
|
||||
SSL4EOS12(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: SSL4EOS12) -> None:
|
||||
sample = dataset[0]
|
||||
|
|
|
@ -25,7 +25,7 @@ from torchgeo.datasets import (
|
|||
)
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -43,7 +43,7 @@ class TestSSL4EOLBenchmark:
|
|||
monkeypatch.setattr(
|
||||
torchgeo.datasets.ssl4eo_benchmark, 'download_url', download_url
|
||||
)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
|
||||
url = os.path.join('tests', 'data', 'ssl4eo_benchmark_landsat', '{}.tar.gz')
|
||||
monkeypatch.setattr(SSL4EOLBenchmark, 'url', url)
|
||||
|
@ -140,14 +140,14 @@ class TestSSL4EOLBenchmark:
|
|||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
pathname = os.path.join('tests', 'data', 'ssl4eo_benchmark_landsat', '*.tar.gz')
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
for tarfile in glob.iglob(pathname):
|
||||
shutil.copy(tarfile, root)
|
||||
SSL4EOLBenchmark(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SSL4EOLBenchmark(str(tmp_path))
|
||||
SSL4EOLBenchmark(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: SSL4EOLBenchmark) -> None:
|
||||
sample = dataset[0]
|
||||
|
|
|
@ -16,7 +16,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import DatasetNotFoundError, SustainBenchCropYield
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -34,7 +34,7 @@ class TestSustainBenchCropYield:
|
|||
url = os.path.join('tests', 'data', 'sustainbench_crop_yield', 'soybeans.zip')
|
||||
monkeypatch.setattr(SustainBenchCropYield, 'url', url)
|
||||
monkeypatch.setattr(plt, 'show', lambda *args: None)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
countries = ['argentina', 'brazil', 'usa']
|
||||
transforms = nn.Identity()
|
||||
|
@ -49,7 +49,7 @@ class TestSustainBenchCropYield:
|
|||
pathname = os.path.join(
|
||||
'tests', 'data', 'sustainbench_crop_yield', 'soybeans.zip'
|
||||
)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
shutil.copy(pathname, root)
|
||||
SustainBenchCropYield(root)
|
||||
|
||||
|
@ -72,7 +72,7 @@ class TestSustainBenchCropYield:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SustainBenchCropYield(str(tmp_path))
|
||||
SustainBenchCropYield(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: SustainBenchCropYield) -> None:
|
||||
dataset.plot(dataset[0], suptitle='Test')
|
||||
|
|
|
@ -17,7 +17,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import DatasetNotFoundError, UCMerced
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -51,7 +51,7 @@ class TestUCMerced:
|
|||
'test': 'a01fa9f13333bb176fc1bfe26ff4c711',
|
||||
},
|
||||
)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return UCMerced(root, split, transforms, download=True, checksum=True)
|
||||
|
@ -71,18 +71,18 @@ class TestUCMerced:
|
|||
assert len(ds) == 8
|
||||
|
||||
def test_already_downloaded(self, dataset: UCMerced, tmp_path: Path) -> None:
|
||||
UCMerced(root=str(tmp_path), download=True)
|
||||
UCMerced(root=tmp_path, download=True)
|
||||
|
||||
def test_already_downloaded_not_extracted(
|
||||
self, dataset: UCMerced, tmp_path: Path
|
||||
) -> None:
|
||||
shutil.rmtree(dataset.root)
|
||||
download_url(dataset.url, root=str(tmp_path))
|
||||
UCMerced(root=str(tmp_path), download=False)
|
||||
download_url(dataset.url, root=tmp_path)
|
||||
UCMerced(root=tmp_path, download=False)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
UCMerced(str(tmp_path))
|
||||
UCMerced(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: UCMerced) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -17,7 +17,7 @@ import torchgeo.datasets.utils
|
|||
from torchgeo.datasets import DatasetNotFoundError, USAVars
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -73,7 +73,7 @@ class TestUSAVars:
|
|||
}
|
||||
monkeypatch.setattr(USAVars, 'split_metadata', split_metadata)
|
||||
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split, labels = request.param
|
||||
transforms = nn.Identity()
|
||||
|
||||
|
@ -109,7 +109,7 @@ class TestUSAVars:
|
|||
|
||||
def test_already_downloaded(self, tmp_path: Path) -> None:
|
||||
pathname = os.path.join('tests', 'data', 'usavars', 'uar.zip')
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
shutil.copy(pathname, root)
|
||||
csvs = [
|
||||
'elevation.csv',
|
||||
|
@ -130,7 +130,7 @@ class TestUSAVars:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
USAVars(str(tmp_path))
|
||||
USAVars(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: USAVars) -> None:
|
||||
dataset.plot(dataset[0], suptitle='Test')
|
||||
|
|
|
@ -65,7 +65,7 @@ def fetch_collection(collection_id: str, **kwargs: str) -> Collection:
|
|||
return Collection()
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -85,7 +85,7 @@ def test_extract_archive(src: str, tmp_path: Path) -> None:
|
|||
pytest.importorskip('rarfile', minversion='4')
|
||||
if src.startswith('chesapeake'):
|
||||
pytest.importorskip('zipfile_deflate64')
|
||||
extract_archive(os.path.join('tests', 'data', src), str(tmp_path))
|
||||
extract_archive(os.path.join('tests', 'data', src), tmp_path)
|
||||
|
||||
|
||||
def test_unsupported_scheme() -> None:
|
||||
|
@ -98,8 +98,7 @@ def test_unsupported_scheme() -> None:
|
|||
def test_download_and_extract_archive(tmp_path: Path, monkeypatch: MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url)
|
||||
download_and_extract_archive(
|
||||
os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip'),
|
||||
str(tmp_path),
|
||||
os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip'), tmp_path
|
||||
)
|
||||
|
||||
|
||||
|
@ -108,7 +107,7 @@ def test_download_radiant_mlhub_dataset(
|
|||
) -> None:
|
||||
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
|
||||
monkeypatch.setattr(radiant_mlhub.Dataset, 'fetch', fetch_dataset)
|
||||
download_radiant_mlhub_dataset('', str(tmp_path))
|
||||
download_radiant_mlhub_dataset('', tmp_path)
|
||||
|
||||
|
||||
def test_download_radiant_mlhub_collection(
|
||||
|
@ -116,7 +115,7 @@ def test_download_radiant_mlhub_collection(
|
|||
) -> None:
|
||||
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
|
||||
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection)
|
||||
download_radiant_mlhub_collection('', str(tmp_path))
|
||||
download_radiant_mlhub_collection('', tmp_path)
|
||||
|
||||
|
||||
class TestBoundingBox:
|
||||
|
|
|
@ -49,9 +49,9 @@ class TestVaihingen2D:
|
|||
]
|
||||
for filename in filenames:
|
||||
shutil.copyfile(
|
||||
os.path.join(root, filename), os.path.join(str(tmp_path), filename)
|
||||
os.path.join(root, filename), os.path.join(tmp_path, filename)
|
||||
)
|
||||
Vaihingen2D(root=str(tmp_path))
|
||||
Vaihingen2D(root=tmp_path)
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
filenames = [
|
||||
|
@ -62,7 +62,7 @@ class TestVaihingen2D:
|
|||
with open(os.path.join(tmp_path, filename), 'w') as f:
|
||||
f.write('bad')
|
||||
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
|
||||
Vaihingen2D(root=str(tmp_path), checksum=True)
|
||||
Vaihingen2D(root=tmp_path, checksum=True)
|
||||
|
||||
def test_invalid_split(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
@ -70,7 +70,7 @@ class TestVaihingen2D:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
Vaihingen2D(str(tmp_path))
|
||||
Vaihingen2D(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: Vaihingen2D) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -20,7 +20,7 @@ pytest.importorskip('pycocotools')
|
|||
pytest.importorskip('rarfile', minversion='4')
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -39,7 +39,7 @@ class TestVHR10:
|
|||
monkeypatch.setitem(VHR10.target_meta, 'url', url)
|
||||
md5 = '567c4cd8c12624864ff04865de504c58'
|
||||
monkeypatch.setitem(VHR10.target_meta, 'md5', md5)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
split = request.param
|
||||
transforms = nn.Identity()
|
||||
return VHR10(root, split, transforms, download=True, checksum=True)
|
||||
|
@ -78,7 +78,7 @@ class TestVHR10:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
VHR10(str(tmp_path))
|
||||
VHR10(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: VHR10) -> None:
|
||||
pytest.importorskip('skimage', minversion='0.19')
|
||||
|
|
|
@ -37,7 +37,7 @@ class TestWesternUSALiveFuelMoisture:
|
|||
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
|
||||
md5 = 'ecbc9269dd27c4efe7aa887960054351'
|
||||
monkeypatch.setattr(WesternUSALiveFuelMoisture, 'md5', md5)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return WesternUSALiveFuelMoisture(
|
||||
root, transforms=transforms, download=True, api_key='', checksum=True
|
||||
|
@ -60,13 +60,13 @@ class TestWesternUSALiveFuelMoisture:
|
|||
'western_usa_live_fuel_moisture',
|
||||
'su_sar_moisture_content.tar.gz',
|
||||
)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
shutil.copy(pathname, root)
|
||||
WesternUSALiveFuelMoisture(root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
WesternUSALiveFuelMoisture(str(tmp_path))
|
||||
WesternUSALiveFuelMoisture(tmp_path)
|
||||
|
||||
def test_invalid_features(self, dataset: WesternUSALiveFuelMoisture) -> None:
|
||||
with pytest.raises(AssertionError, match='Invalid input variable name.'):
|
||||
|
|
|
@ -61,7 +61,7 @@ class TestXView2:
|
|||
),
|
||||
os.path.join(tmp_path, 'test_images_labels_targets.tar.gz'),
|
||||
)
|
||||
XView2(root=str(tmp_path))
|
||||
XView2(root=tmp_path)
|
||||
|
||||
def test_corrupted(self, tmp_path: Path) -> None:
|
||||
with open(
|
||||
|
@ -73,7 +73,7 @@ class TestXView2:
|
|||
) as f:
|
||||
f.write('bad')
|
||||
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
|
||||
XView2(root=str(tmp_path), checksum=True)
|
||||
XView2(root=tmp_path, checksum=True)
|
||||
|
||||
def test_invalid_split(self) -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
|
@ -81,7 +81,7 @@ class TestXView2:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
XView2(str(tmp_path))
|
||||
XView2(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: XView2) -> None:
|
||||
x = dataset[0].copy()
|
||||
|
|
|
@ -17,7 +17,7 @@ from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, ZueriC
|
|||
pytest.importorskip('h5py', minversion='3.6')
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
|
||||
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
|
@ -33,7 +33,7 @@ class TestZueriCrop:
|
|||
md5s = ['1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b']
|
||||
monkeypatch.setattr(ZueriCrop, 'urls', urls)
|
||||
monkeypatch.setattr(ZueriCrop, 'md5s', md5s)
|
||||
root = str(tmp_path)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return ZueriCrop(root=root, transforms=transforms, download=True, checksum=True)
|
||||
|
||||
|
@ -67,7 +67,7 @@ class TestZueriCrop:
|
|||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
ZueriCrop(str(tmp_path))
|
||||
ZueriCrop(tmp_path)
|
||||
|
||||
def test_invalid_bands(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
|
|
|
@ -17,7 +17,7 @@ from torch import Tensor
|
|||
|
||||
from .errors import DatasetNotFoundError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import download_and_extract_archive, lazy_import
|
||||
from .utils import Path, download_and_extract_archive, lazy_import
|
||||
|
||||
|
||||
class ADVANCE(NonGeoDataset):
|
||||
|
@ -88,7 +88,7 @@ class ADVANCE(NonGeoDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = 'data',
|
||||
root: Path = 'data',
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
|
@ -151,7 +151,7 @@ class ADVANCE(NonGeoDataset):
|
|||
"""
|
||||
return len(self.files)
|
||||
|
||||
def _load_files(self, root: str) -> list[dict[str, str]]:
|
||||
def _load_files(self, root: Path) -> list[dict[str, str]]:
|
||||
"""Return the paths of the files in the dataset.
|
||||
|
||||
Args:
|
||||
|
@ -169,7 +169,7 @@ class ADVANCE(NonGeoDataset):
|
|||
]
|
||||
return files
|
||||
|
||||
def _load_image(self, path: str) -> Tensor:
|
||||
def _load_image(self, path: Path) -> Tensor:
|
||||
"""Load a single image.
|
||||
|
||||
Args:
|
||||
|
@ -185,7 +185,7 @@ class ADVANCE(NonGeoDataset):
|
|||
tensor = tensor.permute((2, 0, 1))
|
||||
return tensor
|
||||
|
||||
def _load_target(self, path: str) -> Tensor:
|
||||
def _load_target(self, path: Path) -> Tensor:
|
||||
"""Load the target audio for a single image.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any
|
||||
|
||||
|
@ -14,7 +15,7 @@ from rasterio.crs import CRS
|
|||
|
||||
from .errors import DatasetNotFoundError
|
||||
from .geo import RasterDataset
|
||||
from .utils import download_url
|
||||
from .utils import Path, download_url
|
||||
|
||||
|
||||
class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
|
||||
|
@ -57,7 +58,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: str | Iterable[str] = 'data',
|
||||
paths: Path | Iterable[Path] = 'data',
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
|
@ -105,7 +106,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
|
|||
|
||||
def _download(self) -> None:
|
||||
"""Download the dataset."""
|
||||
assert isinstance(self.paths, str)
|
||||
assert isinstance(self.paths, str | pathlib.Path)
|
||||
download_url(self.url, self.paths, self.base_filename)
|
||||
|
||||
with open(os.path.join(self.paths, self.base_filename)) as f:
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
"""AgriFieldNet India Challenge dataset."""
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import Any, cast
|
||||
|
@ -16,7 +17,7 @@ from torch import Tensor
|
|||
|
||||
from .errors import RGBBandsMissingError
|
||||
from .geo import RasterDataset
|
||||
from .utils import BoundingBox
|
||||
from .utils import BoundingBox, Path
|
||||
|
||||
|
||||
class AgriFieldNet(RasterDataset):
|
||||
|
@ -115,7 +116,7 @@ class AgriFieldNet(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: str | Iterable[str] = 'data',
|
||||
paths: Path | Iterable[Path] = 'data',
|
||||
crs: CRS | None = None,
|
||||
classes: list[int] = list(cmap.keys()),
|
||||
bands: Sequence[str] = all_bands,
|
||||
|
@ -167,10 +168,10 @@ class AgriFieldNet(RasterDataset):
|
|||
Returns:
|
||||
data, label, and field ids at that index
|
||||
"""
|
||||
assert isinstance(self.paths, str)
|
||||
assert isinstance(self.paths, str | pathlib.Path)
|
||||
|
||||
hits = self.index.intersection(tuple(query), objects=True)
|
||||
filepaths = cast(list[str], [hit.object for hit in hits])
|
||||
filepaths = cast(list[Path], [hit.object for hit in hits])
|
||||
|
||||
if not filepaths:
|
||||
raise IndexError(
|
||||
|
|
|
@ -12,6 +12,7 @@ from rasterio.crs import CRS
|
|||
|
||||
from .errors import DatasetNotFoundError
|
||||
from .geo import RasterDataset
|
||||
from .utils import Path
|
||||
|
||||
|
||||
class AsterGDEM(RasterDataset):
|
||||
|
@ -47,7 +48,7 @@ class AsterGDEM(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: str | list[str] = 'data',
|
||||
paths: Path | list[Path] = 'data',
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
|
|
|
@ -19,7 +19,7 @@ from torch import Tensor
|
|||
|
||||
from .errors import DatasetNotFoundError, RGBBandsMissingError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import which
|
||||
from .utils import Path, which
|
||||
|
||||
|
||||
class BeninSmallHolderCashews(NonGeoDataset):
|
||||
|
@ -163,7 +163,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = 'data',
|
||||
root: Path = 'data',
|
||||
chip_size: int = 256,
|
||||
stride: int = 128,
|
||||
bands: Sequence[str] = all_bands,
|
||||
|
|
|
@ -18,7 +18,7 @@ from torch import Tensor
|
|||
|
||||
from .errors import DatasetNotFoundError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import download_url, extract_archive, sort_sentinel2_bands
|
||||
from .utils import Path, download_url, extract_archive, sort_sentinel2_bands
|
||||
|
||||
|
||||
class BigEarthNet(NonGeoDataset):
|
||||
|
@ -267,7 +267,7 @@ class BigEarthNet(NonGeoDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = 'data',
|
||||
root: Path = 'data',
|
||||
split: str = 'train',
|
||||
bands: str = 'all',
|
||||
num_classes: int = 19,
|
||||
|
@ -486,7 +486,7 @@ class BigEarthNet(NonGeoDataset):
|
|||
filepath = os.path.join(self.root, filename)
|
||||
self._extract(filepath)
|
||||
|
||||
def _download(self, url: str, filename: str, md5: str) -> None:
|
||||
def _download(self, url: str, filename: Path, md5: str) -> None:
|
||||
"""Download the dataset.
|
||||
|
||||
Args:
|
||||
|
@ -499,13 +499,13 @@ class BigEarthNet(NonGeoDataset):
|
|||
url, self.root, filename=filename, md5=md5 if self.checksum else None
|
||||
)
|
||||
|
||||
def _extract(self, filepath: str) -> None:
|
||||
def _extract(self, filepath: Path) -> None:
|
||||
"""Extract the dataset.
|
||||
|
||||
Args:
|
||||
filepath: path to file to be extracted
|
||||
"""
|
||||
if not filepath.endswith('.csv'):
|
||||
if not str(filepath).endswith('.csv'):
|
||||
extract_archive(filepath)
|
||||
|
||||
def _onehot_labels_to_names(
|
||||
|
|
|
@ -16,7 +16,7 @@ from torch import Tensor
|
|||
|
||||
from .errors import DatasetNotFoundError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import percentile_normalization
|
||||
from .utils import Path, percentile_normalization
|
||||
|
||||
|
||||
class BioMassters(NonGeoDataset):
|
||||
|
@ -57,7 +57,7 @@ class BioMassters(NonGeoDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = 'data',
|
||||
root: Path = 'data',
|
||||
split: str = 'train',
|
||||
sensors: Sequence[str] = ['S1', 'S2'],
|
||||
as_time_series: bool = False,
|
||||
|
@ -167,7 +167,7 @@ class BioMassters(NonGeoDataset):
|
|||
"""
|
||||
return len(self.df['num_index'].unique())
|
||||
|
||||
def _load_input(self, filenames: list[str]) -> Tensor:
|
||||
def _load_input(self, filenames: list[Path]) -> Tensor:
|
||||
"""Load the input imagery at the index.
|
||||
|
||||
Args:
|
||||
|
@ -186,7 +186,7 @@ class BioMassters(NonGeoDataset):
|
|||
arr = np.concatenate(arr_list, axis=0)
|
||||
return torch.tensor(arr.astype(np.int32))
|
||||
|
||||
def _load_target(self, filename: str) -> Tensor:
|
||||
def _load_target(self, filename: Path) -> Tensor:
|
||||
"""Load the target mask at the index.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
"""Canadian Building Footprints dataset."""
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any
|
||||
|
||||
|
@ -13,7 +14,7 @@ from rasterio.crs import CRS
|
|||
|
||||
from .errors import DatasetNotFoundError
|
||||
from .geo import VectorDataset
|
||||
from .utils import check_integrity, download_and_extract_archive
|
||||
from .utils import Path, check_integrity, download_and_extract_archive
|
||||
|
||||
|
||||
class CanadianBuildingFootprints(VectorDataset):
|
||||
|
@ -62,7 +63,7 @@ class CanadianBuildingFootprints(VectorDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: str | Iterable[str] = 'data',
|
||||
paths: Path | Iterable[Path] = 'data',
|
||||
crs: CRS | None = None,
|
||||
res: float = 0.00001,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
|
@ -104,7 +105,7 @@ class CanadianBuildingFootprints(VectorDataset):
|
|||
Returns:
|
||||
True if dataset files are found and/or MD5s match, else False
|
||||
"""
|
||||
assert isinstance(self.paths, str)
|
||||
assert isinstance(self.paths, str | pathlib.Path)
|
||||
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
|
||||
filepath = os.path.join(self.paths, prov_terr + '.zip')
|
||||
if not check_integrity(filepath, md5 if self.checksum else None):
|
||||
|
@ -116,7 +117,7 @@ class CanadianBuildingFootprints(VectorDataset):
|
|||
if self._check_integrity():
|
||||
print('Files already downloaded and verified')
|
||||
return
|
||||
assert isinstance(self.paths, str)
|
||||
assert isinstance(self.paths, str | pathlib.Path)
|
||||
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
|
||||
download_and_extract_archive(
|
||||
self.url + prov_terr + '.zip',
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
"""CDL dataset."""
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any
|
||||
|
||||
|
@ -14,7 +15,7 @@ from rasterio.crs import CRS
|
|||
|
||||
from .errors import DatasetNotFoundError
|
||||
from .geo import RasterDataset
|
||||
from .utils import BoundingBox, download_url, extract_archive
|
||||
from .utils import BoundingBox, Path, download_url, extract_archive
|
||||
|
||||
|
||||
class CDL(RasterDataset):
|
||||
|
@ -207,7 +208,7 @@ class CDL(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: str | Iterable[str] = 'data',
|
||||
paths: Path | Iterable[Path] = 'data',
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
years: list[int] = [2023],
|
||||
|
@ -294,7 +295,7 @@ class CDL(RasterDataset):
|
|||
|
||||
# Check if the zip files have already been downloaded
|
||||
exists = []
|
||||
assert isinstance(self.paths, str)
|
||||
assert isinstance(self.paths, str | pathlib.Path)
|
||||
for year in self.years:
|
||||
pathname = os.path.join(
|
||||
self.paths, self.zipfile_glob.replace('*', str(year))
|
||||
|
@ -327,7 +328,7 @@ class CDL(RasterDataset):
|
|||
|
||||
def _extract(self) -> None:
|
||||
"""Extract the dataset."""
|
||||
assert isinstance(self.paths, str)
|
||||
assert isinstance(self.paths, str | pathlib.Path)
|
||||
for year in self.years:
|
||||
zipfile_name = self.zipfile_glob.replace('*', str(year))
|
||||
pathname = os.path.join(self.paths, zipfile_name)
|
||||
|
|
|
@ -14,7 +14,7 @@ from torch import Tensor
|
|||
|
||||
from .errors import DatasetNotFoundError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import download_url, lazy_import, percentile_normalization
|
||||
from .utils import Path, download_url, lazy_import, percentile_normalization
|
||||
|
||||
|
||||
class ChaBuD(NonGeoDataset):
|
||||
|
@ -75,7 +75,7 @@ class ChaBuD(NonGeoDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = 'data',
|
||||
root: Path = 'data',
|
||||
split: str = 'train',
|
||||
bands: list[str] = all_bands,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
|
||||
import abc
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import Any, cast
|
||||
|
@ -26,7 +27,7 @@ from torch import Tensor
|
|||
from .errors import DatasetNotFoundError
|
||||
from .geo import GeoDataset, RasterDataset
|
||||
from .nlcd import NLCD
|
||||
from .utils import BoundingBox, download_url, extract_archive
|
||||
from .utils import BoundingBox, Path, download_url, extract_archive
|
||||
|
||||
|
||||
class Chesapeake(RasterDataset, abc.ABC):
|
||||
|
@ -91,7 +92,7 @@ class Chesapeake(RasterDataset, abc.ABC):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: str | Iterable[str] = 'data',
|
||||
paths: Path | Iterable[Path] = 'data',
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
|
@ -145,7 +146,7 @@ class Chesapeake(RasterDataset, abc.ABC):
|
|||
return
|
||||
|
||||
# Check if the zip file has already been downloaded
|
||||
assert isinstance(self.paths, str)
|
||||
assert isinstance(self.paths, str | pathlib.Path)
|
||||
if os.path.exists(os.path.join(self.paths, self.zipfile)):
|
||||
self._extract()
|
||||
return
|
||||
|
@ -164,7 +165,7 @@ class Chesapeake(RasterDataset, abc.ABC):
|
|||
|
||||
def _extract(self) -> None:
|
||||
"""Extract the dataset."""
|
||||
assert isinstance(self.paths, str)
|
||||
assert isinstance(self.paths, str | pathlib.Path)
|
||||
extract_archive(os.path.join(self.paths, self.zipfile))
|
||||
|
||||
def plot(
|
||||
|
@ -510,7 +511,7 @@ class ChesapeakeCVPR(GeoDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = 'data',
|
||||
root: Path = 'data',
|
||||
splits: Sequence[str] = ['de-train'],
|
||||
layers: Sequence[str] = ['naip-new', 'lc'],
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
|
@ -668,7 +669,7 @@ class ChesapeakeCVPR(GeoDataset):
|
|||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset."""
|
||||
|
||||
def exists(filename: str) -> bool:
|
||||
def exists(filename: Path) -> bool:
|
||||
return os.path.exists(os.path.join(self.root, filename))
|
||||
|
||||
# Check if the extracted files already exist
|
||||
|
|
|
@ -16,7 +16,7 @@ from torch import Tensor
|
|||
|
||||
from .errors import DatasetNotFoundError, RGBBandsMissingError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import which
|
||||
from .utils import Path, which
|
||||
|
||||
|
||||
class CloudCoverDetection(NonGeoDataset):
|
||||
|
@ -61,7 +61,7 @@ class CloudCoverDetection(NonGeoDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = 'data',
|
||||
root: Path = 'data',
|
||||
split: str = 'train',
|
||||
bands: Sequence[str] = all_bands,
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
"""CMS Global Mangrove Canopy dataset."""
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
|
@ -13,7 +14,7 @@ from rasterio.crs import CRS
|
|||
|
||||
from .errors import DatasetNotFoundError
|
||||
from .geo import RasterDataset
|
||||
from .utils import check_integrity, extract_archive
|
||||
from .utils import Path, check_integrity, extract_archive
|
||||
|
||||
|
||||
class CMSGlobalMangroveCanopy(RasterDataset):
|
||||
|
@ -169,7 +170,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
paths: str | list[str] = 'data',
|
||||
paths: Path | list[Path] = 'data',
|
||||
crs: CRS | None = None,
|
||||
res: float | None = None,
|
||||
measurement: str = 'agb',
|
||||
|
@ -228,7 +229,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
|
|||
return
|
||||
|
||||
# Check if the zip file has already been downloaded
|
||||
assert isinstance(self.paths, str)
|
||||
assert isinstance(self.paths, str | pathlib.Path)
|
||||
pathname = os.path.join(self.paths, self.zipfile)
|
||||
if os.path.exists(pathname):
|
||||
if self.checksum and not check_integrity(pathname, self.md5):
|
||||
|
@ -240,7 +241,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
|
|||
|
||||
def _extract(self) -> None:
|
||||
"""Extract the dataset."""
|
||||
assert isinstance(self.paths, str)
|
||||
assert isinstance(self.paths, str | pathlib.Path)
|
||||
pathname = os.path.join(self.paths, self.zipfile)
|
||||
extract_archive(pathname)
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ from torch import Tensor
|
|||
|
||||
from .errors import DatasetNotFoundError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import check_integrity, download_and_extract_archive
|
||||
from .utils import Path, check_integrity, download_and_extract_archive
|
||||
|
||||
|
||||
class COWC(NonGeoDataset, abc.ABC):
|
||||
|
@ -65,7 +65,7 @@ class COWC(NonGeoDataset, abc.ABC):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = 'data',
|
||||
root: Path = 'data',
|
||||
split: str = 'train',
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
|
|
|
@ -17,7 +17,7 @@ from torch import Tensor
|
|||
|
||||
from .errors import DatasetNotFoundError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import download_url, extract_archive, lazy_import
|
||||
from .utils import Path, download_url, extract_archive, lazy_import
|
||||
|
||||
|
||||
class CropHarvest(NonGeoDataset):
|
||||
|
@ -96,7 +96,7 @@ class CropHarvest(NonGeoDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = 'data',
|
||||
root: Path = 'data',
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
|
@ -157,7 +157,7 @@ class CropHarvest(NonGeoDataset):
|
|||
"""
|
||||
return len(self.files)
|
||||
|
||||
def _load_features(self, root: str) -> list[dict[str, str]]:
|
||||
def _load_features(self, root: Path) -> list[dict[str, str]]:
|
||||
"""Return the paths of the files in the dataset.
|
||||
|
||||
Args:
|
||||
|
@ -181,7 +181,7 @@ class CropHarvest(NonGeoDataset):
|
|||
files.append(dict(chip=chip_path, index=index, dataset=dataset))
|
||||
return files
|
||||
|
||||
def _load_labels(self, root: str) -> pd.DataFrame:
|
||||
def _load_labels(self, root: Path) -> pd.DataFrame:
|
||||
"""Return the paths of the files in the dataset.
|
||||
|
||||
Args:
|
||||
|
@ -196,7 +196,7 @@ class CropHarvest(NonGeoDataset):
|
|||
df = pd.json_normalize(data['features'])
|
||||
return df
|
||||
|
||||
def _load_array(self, path: str) -> Tensor:
|
||||
def _load_array(self, path: Path) -> Tensor:
|
||||
"""Load an individual single pixel time series.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -16,7 +16,7 @@ from torch import Tensor
|
|||
|
||||
from .errors import DatasetNotFoundError, RGBBandsMissingError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import which
|
||||
from .utils import Path, which
|
||||
|
||||
|
||||
class CV4AKenyaCropType(NonGeoDataset):
|
||||
|
@ -104,7 +104,7 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = 'data',
|
||||
root: Path = 'data',
|
||||
chip_size: int = 256,
|
||||
stride: int = 128,
|
||||
bands: Sequence[str] = all_bands,
|
||||
|
|
|
@ -18,7 +18,7 @@ from torch import Tensor
|
|||
|
||||
from .errors import DatasetNotFoundError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import which
|
||||
from .utils import Path, which
|
||||
|
||||
|
||||
class TropicalCyclone(NonGeoDataset):
|
||||
|
@ -53,7 +53,7 @@ class TropicalCyclone(NonGeoDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = 'data',
|
||||
root: Path = 'data',
|
||||
split: str = 'train',
|
||||
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
|
||||
download: bool = False,
|
||||
|
|
|
@ -16,6 +16,7 @@ from torch import Tensor
|
|||
from .errors import DatasetNotFoundError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import (
|
||||
Path,
|
||||
check_integrity,
|
||||
draw_semantic_segmentation_masks,
|
||||
extract_archive,
|
||||
|
@ -100,7 +101,7 @@ class DeepGlobeLandCover(NonGeoDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = 'data',
|
||||
root: Path = 'data',
|
||||
split: str = 'train',
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
checksum: bool = False,
|
||||
|
|
|
@ -18,7 +18,7 @@ from torch import Tensor
|
|||
|
||||
from .errors import DatasetNotFoundError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import check_integrity, extract_archive, percentile_normalization
|
||||
from .utils import Path, check_integrity, extract_archive, percentile_normalization
|
||||
|
||||
|
||||
class DFC2022(NonGeoDataset):
|
||||
|
@ -137,7 +137,7 @@ class DFC2022(NonGeoDataset):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = 'data',
|
||||
root: Path = 'data',
|
||||
split: str = 'train',
|
||||
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
|
||||
checksum: bool = False,
|
||||
|
@ -224,7 +224,7 @@ class DFC2022(NonGeoDataset):
|
|||
|
||||
return files
|
||||
|
||||
def _load_image(self, path: str, shape: Sequence[int] | None = None) -> Tensor:
|
||||
def _load_image(self, path: Path, shape: Sequence[int] | None = None) -> Tensor:
|
||||
"""Load a single image.
|
||||
|
||||
Args:
|
||||
|
@ -241,7 +241,7 @@ class DFC2022(NonGeoDataset):
|
|||
tensor = torch.from_numpy(array)
|
||||
return tensor
|
||||
|
||||
def _load_target(self, path: str) -> Tensor:
|
||||
def _load_target(self, path: Path) -> Tensor:
|
||||
"""Load the target mask for a single image.
|
||||
|
||||
Args:
|
||||
|
|
Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше
Загрузка…
Ссылка в новой задаче