Datasets: add support for pathlib.Path (#2173)

* Added pathlib support for ``geo.py``

* Fixed failing ruff checks.

* Fixed additional ruff formatting errors

* Added complete ``pathlib`` support

* Additional changes

* Fixed ``cyclone.py`` issues

* Additional fixes

* Fixed ``isinstance`` and ``Path`` inconsistency

* Fixed mypy errors in ``cdl.py``

* geo/utils: all paths are Paths

* datasets: all paths are Paths

* Test Paths

* Type checks only work for latest torchvision

* Fix tests

---------

Co-authored-by: Hitesh Tolani <hitesh.ht.2003@gmail.com>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
Hitesh Tolani 2024-07-18 15:46:41 +05:30 committed by GitHub
Parent 67eaeba79c
Commit 44ce0073b2
No key found matching this signature
GPG key ID: B5690EEEBB952194
158 changed files with 658 additions and 598 deletions
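For context, a minimal usage sketch (not part of the commit) of what the change enables: dataset roots can now be passed as ``pathlib.Path`` objects as well as ``str``. ``EuroSAT`` is one of the datasets updated below; the local directory name is a placeholder.

from pathlib import Path

from torchgeo.datasets import EuroSAT

# Before this commit, roots had to be cast with str(); both forms now work.
ds_from_str = EuroSAT(root='data/eurosat', download=True)
ds_from_path = EuroSAT(root=Path('data') / 'eurosat', download=True)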

View file

@ -32,6 +32,7 @@ repos:
- scikit-image>=0.22.0
- torch>=2.3
- torchmetrics>=0.10
- torchvision>=0.18
exclude: (build|data|dist|logo|logs|output)/
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.1.0

View file

@ -17,7 +17,7 @@ from torchgeo.datasets import ADVANCE, DatasetNotFoundError
pytest.importorskip('scipy', minversion='1.7.2')
def download_url(url: str, root: str, *args: str) -> None:
def download_url(url: str, root: str | Path, *args: str) -> None:
shutil.copy(url, root)
@ -33,7 +33,7 @@ class TestADVANCE:
md5s = ['43acacecebecd17a82bc2c1e719fd7e4', '039b7baa47879a8a4e32b9dd8287f6ad']
monkeypatch.setattr(ADVANCE, 'urls', urls)
monkeypatch.setattr(ADVANCE, 'md5s', md5s)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return ADVANCE(root, transforms, download=True, checksum=True)
@ -57,7 +57,7 @@ class TestADVANCE:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ADVANCE(str(tmp_path))
ADVANCE(tmp_path)
def test_plot(self, dataset: ADVANCE) -> None:
x = dataset[0].copy()
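Every test module in this commit applies the same widening to its fake downloader: the tests monkeypatch ``torchgeo.datasets.utils.download_url`` with a local stand-in that copies fixture files, so the stand-in's signature must accept ``str | Path`` too. A self-contained sketch of that pattern, mirroring the diff above:

import shutil
from pathlib import Path


def download_url(url: str, root: str | Path, *args: str) -> None:
    # shutil.copy accepts both str and os.PathLike, so no str() cast is needed.
    shutil.copy(url, root)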

View file

@ -21,7 +21,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -42,7 +42,7 @@ class TestAbovegroundLiveWoodyBiomassDensity:
)
monkeypatch.setattr(AbovegroundLiveWoodyBiomassDensity, 'url', url)
root = str(tmp_path)
root = tmp_path
return AbovegroundLiveWoodyBiomassDensity(
root, transforms=transforms, download=True
)
@ -58,7 +58,7 @@ class TestAbovegroundLiveWoodyBiomassDensity:
def test_no_dataset(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
AbovegroundLiveWoodyBiomassDensity(str(tmp_path))
AbovegroundLiveWoodyBiomassDensity(tmp_path)
def test_already_downloaded(
self, dataset: AbovegroundLiveWoodyBiomassDensity

View file

@ -50,7 +50,7 @@ class TestAgriFieldNet:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
AgriFieldNet(str(tmp_path))
AgriFieldNet(tmp_path)
def test_plot(self, dataset: AgriFieldNet) -> None:
x = dataset[dataset.bounds]

View file

@ -52,7 +52,7 @@ class TestAirphen:
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
Airphen(str(tmp_path))
Airphen(tmp_path)
def test_invalid_query(self, dataset: Airphen) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)

View file

@ -25,7 +25,7 @@ class TestAsterGDEM:
def dataset(self, tmp_path: Path) -> AsterGDEM:
zipfile = os.path.join('tests', 'data', 'astergdem', 'astergdem.zip')
shutil.unpack_archive(zipfile, tmp_path, 'zip')
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return AsterGDEM(root, transforms=transforms)
@ -33,7 +33,7 @@ class TestAsterGDEM:
shutil.rmtree(tmp_path)
os.makedirs(tmp_path)
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
AsterGDEM(str(tmp_path))
AsterGDEM(tmp_path)
def test_getitem(self, dataset: AsterGDEM) -> None:
x = dataset[dataset.bounds]

View file

@ -29,7 +29,7 @@ class TestBeninSmallHolderCashews:
monkeypatch.setattr(BeninSmallHolderCashews, 'dates', ('20191105',))
monkeypatch.setattr(BeninSmallHolderCashews, 'tile_height', 2)
monkeypatch.setattr(BeninSmallHolderCashews, 'tile_width', 2)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return BeninSmallHolderCashews(root, transforms=transforms, download=True)
@ -54,7 +54,7 @@ class TestBeninSmallHolderCashews:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
BeninSmallHolderCashews(str(tmp_path))
BeninSmallHolderCashews(tmp_path)
def test_invalid_bands(self) -> None:
with pytest.raises(AssertionError):

View file

@ -16,7 +16,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import BigEarthNet, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -63,7 +63,7 @@ class TestBigEarthNet:
monkeypatch.setattr(BigEarthNet, 'metadata', metadata)
monkeypatch.setattr(BigEarthNet, 'splits_metadata', splits_metadata)
bands, num_classes, split = request.param
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return BigEarthNet(
root, split, bands, num_classes, transforms, download=True, checksum=True
@ -95,7 +95,7 @@ class TestBigEarthNet:
def test_already_downloaded(self, dataset: BigEarthNet, tmp_path: Path) -> None:
BigEarthNet(
root=str(tmp_path),
root=tmp_path,
bands=dataset.bands,
split=dataset.split,
num_classes=dataset.num_classes,
@ -112,21 +112,21 @@ class TestBigEarthNet:
shutil.rmtree(
os.path.join(dataset.root, dataset.metadata['s2']['directory'])
)
download_url(dataset.metadata['s1']['url'], root=str(tmp_path))
download_url(dataset.metadata['s2']['url'], root=str(tmp_path))
download_url(dataset.metadata['s1']['url'], root=tmp_path)
download_url(dataset.metadata['s2']['url'], root=tmp_path)
elif dataset.bands == 's1':
shutil.rmtree(
os.path.join(dataset.root, dataset.metadata['s1']['directory'])
)
download_url(dataset.metadata['s1']['url'], root=str(tmp_path))
download_url(dataset.metadata['s1']['url'], root=tmp_path)
else:
shutil.rmtree(
os.path.join(dataset.root, dataset.metadata['s2']['directory'])
)
download_url(dataset.metadata['s2']['url'], root=str(tmp_path))
download_url(dataset.metadata['s2']['url'], root=tmp_path)
BigEarthNet(
root=str(tmp_path),
root=tmp_path,
bands=dataset.bands,
split=dataset.split,
num_classes=dataset.num_classes,
@ -135,7 +135,7 @@ class TestBigEarthNet:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
BigEarthNet(str(tmp_path))
BigEarthNet(tmp_path)
def test_plot(self, dataset: BigEarthNet) -> None:
x = dataset[0].copy()

View file

@ -37,7 +37,7 @@ class TestBioMassters:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
BioMassters(str(tmp_path))
BioMassters(tmp_path)
def test_plot(self, dataset: BioMassters) -> None:
dataset.plot(dataset[0], suptitle='Test')

View file

@ -22,7 +22,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str) -> None:
def download_url(url: str, root: str | Path, *args: str) -> None:
shutil.copy(url, root)
@ -41,7 +41,7 @@ class TestCanadianBuildingFootprints:
url = os.path.join('tests', 'data', 'cbf') + os.sep
monkeypatch.setattr(CanadianBuildingFootprints, 'url', url)
monkeypatch.setattr(plt, 'show', lambda *args: None)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return CanadianBuildingFootprints(
root, res=0.1, transforms=transforms, download=True, checksum=True
@ -80,7 +80,7 @@ class TestCanadianBuildingFootprints:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
CanadianBuildingFootprints(str(tmp_path))
CanadianBuildingFootprints(tmp_path)
def test_invalid_query(self, dataset: CanadianBuildingFootprints) -> None:
query = BoundingBox(2, 2, 2, 2, 2, 2)

View file

@ -24,7 +24,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -41,7 +41,7 @@ class TestCDL:
url = os.path.join('tests', 'data', 'cdl', '{}_30m_cdls.zip')
monkeypatch.setattr(CDL, 'url', url)
monkeypatch.setattr(plt, 'show', lambda *args: None)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return CDL(
root,
@ -87,7 +87,7 @@ class TestCDL:
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'cdl', '*_30m_cdls.zip')
root = str(tmp_path)
root = tmp_path
for zipfile in glob.iglob(pathname):
shutil.copy(zipfile, root)
CDL(root, years=[2023, 2022])
@ -97,7 +97,7 @@ class TestCDL:
AssertionError,
match='CDL data product only exists for the following years:',
):
CDL(str(tmp_path), years=[1996])
CDL(tmp_path, years=[1996])
def test_invalid_classes(self) -> None:
with pytest.raises(AssertionError):
@ -121,7 +121,7 @@ class TestCDL:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
CDL(str(tmp_path))
CDL(tmp_path)
def test_invalid_query(self, dataset: CDL) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)

View file

@ -18,7 +18,9 @@ from torchgeo.datasets import ChaBuD, DatasetNotFoundError
pytest.importorskip('h5py', minversion='3.6')
def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
def download_url(
url: str, root: str | Path, filename: str, *args: str, **kwargs: str
) -> None:
shutil.copy(url, os.path.join(root, filename))
@ -34,7 +36,7 @@ class TestChaBuD:
monkeypatch.setattr(ChaBuD, 'url', url)
monkeypatch.setattr(ChaBuD, 'md5', md5)
bands, split = request.param
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return ChaBuD(
root=root,
@ -70,7 +72,7 @@ class TestChaBuD:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ChaBuD(str(tmp_path))
ChaBuD(tmp_path)
def test_invalid_bands(self) -> None:
with pytest.raises(AssertionError):

View file

@ -26,7 +26,7 @@ from torchgeo.datasets import (
pytest.importorskip('zipfile_deflate64')
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -41,7 +41,7 @@ class TestChesapeake13:
)
monkeypatch.setattr(Chesapeake13, 'url', url)
monkeypatch.setattr(plt, 'show', lambda *args: None)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return Chesapeake13(root, transforms=transforms, download=True, checksum=True)
@ -69,13 +69,13 @@ class TestChesapeake13:
url = os.path.join(
'tests', 'data', 'chesapeake', 'BAYWIDE', 'Baywide_13Class_20132014.zip'
)
root = str(tmp_path)
root = tmp_path
shutil.copy(url, root)
Chesapeake13(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
Chesapeake13(str(tmp_path), checksum=True)
Chesapeake13(tmp_path, checksum=True)
def test_plot(self, dataset: Chesapeake13) -> None:
query = dataset.bounds
@ -148,7 +148,7 @@ class TestChesapeakeCVPR:
'_files',
['de_1m_2013_extended-debuffered-test_tiles', 'spatial_index.geojson'],
)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return ChesapeakeCVPR(
root,
@ -180,7 +180,7 @@ class TestChesapeakeCVPR:
ChesapeakeCVPR(root=dataset.root, download=True)
def test_already_downloaded(self, tmp_path: Path) -> None:
root = str(tmp_path)
root = tmp_path
shutil.copy(
os.path.join(
'tests', 'data', 'chesapeake', 'cvpr', 'cvpr_chesapeake_landcover.zip'
@ -201,7 +201,7 @@ class TestChesapeakeCVPR:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ChesapeakeCVPR(str(tmp_path), checksum=True)
ChesapeakeCVPR(tmp_path, checksum=True)
def test_out_of_bounds_query(self, dataset: ChesapeakeCVPR) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)

View file

@ -30,7 +30,7 @@ class TestCloudCoverDetection:
) -> CloudCoverDetection:
url = os.path.join('tests', 'data', 'ref_cloud_cover_detection_challenge_v1')
monkeypatch.setattr(CloudCoverDetection, 'url', url)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return CloudCoverDetection(
@ -55,7 +55,7 @@ class TestCloudCoverDetection:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
CloudCoverDetection(str(tmp_path))
CloudCoverDetection(tmp_path)
def test_plot(self, dataset: CloudCoverDetection) -> None:
sample = dataset[0]

View file

@ -20,7 +20,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -54,7 +54,7 @@ class TestCMSGlobalMangroveCanopy:
def test_no_dataset(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
CMSGlobalMangroveCanopy(str(tmp_path))
CMSGlobalMangroveCanopy(tmp_path)
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join(
@ -63,7 +63,7 @@ class TestCMSGlobalMangroveCanopy:
'cms_mangrove_canopy',
'CMS_Global_Map_Mangrove_Canopy_1665.zip',
)
root = str(tmp_path)
root = tmp_path
shutil.copy(pathname, root)
CMSGlobalMangroveCanopy(root, country='Angola')
@ -73,7 +73,7 @@ class TestCMSGlobalMangroveCanopy:
) as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
CMSGlobalMangroveCanopy(str(tmp_path), country='Angola', checksum=True)
CMSGlobalMangroveCanopy(tmp_path, country='Angola', checksum=True)
def test_invalid_country(self) -> None:
with pytest.raises(AssertionError):

View file

@ -17,7 +17,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import COWC, COWCCounting, COWCDetection, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -46,7 +46,7 @@ class TestCOWCCounting:
'0a4daed8c5f6c4e20faa6e38636e4346',
]
monkeypatch.setattr(COWCCounting, 'md5s', md5s)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return COWCCounting(root, split, transforms, download=True, checksum=True)
@ -78,7 +78,7 @@ class TestCOWCCounting:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
COWCCounting(str(tmp_path))
COWCCounting(tmp_path)
def test_plot(self, dataset: COWCCounting) -> None:
x = dataset[0].copy()
@ -110,7 +110,7 @@ class TestCOWCDetection:
'dccc2257e9c4a9dde2b4f84769804046',
]
monkeypatch.setattr(COWCDetection, 'md5s', md5s)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return COWCDetection(root, split, transforms, download=True, checksum=True)
@ -142,7 +142,7 @@ class TestCOWCDetection:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
COWCDetection(str(tmp_path))
COWCDetection(tmp_path)
def test_plot(self, dataset: COWCDetection) -> None:
x = dataset[0].copy()

View file

@ -17,7 +17,7 @@ from torchgeo.datasets import CropHarvest, DatasetNotFoundError
pytest.importorskip('h5py', minversion='3.6')
def download_url(url: str, root: str, filename: str, md5: str) -> None:
def download_url(url: str, root: str | Path, filename: str, md5: str) -> None:
shutil.copy(url, os.path.join(root, filename))
@ -42,7 +42,7 @@ class TestCropHarvest:
os.path.join('tests', 'data', 'cropharvest', 'labels.geojson'),
)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
dataset = CropHarvest(root, transforms, download=True, checksum=True)
@ -61,16 +61,16 @@ class TestCropHarvest:
assert len(dataset) == 5
def test_already_downloaded(self, dataset: CropHarvest, tmp_path: Path) -> None:
CropHarvest(root=str(tmp_path), download=False)
CropHarvest(root=tmp_path, download=False)
def test_downloaded_zipped(self, dataset: CropHarvest, tmp_path: Path) -> None:
feature_path = os.path.join(tmp_path, 'features')
shutil.rmtree(feature_path)
CropHarvest(root=str(tmp_path), download=True)
CropHarvest(root=tmp_path, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
CropHarvest(str(tmp_path))
CropHarvest(tmp_path)
def test_plot(self, dataset: CropHarvest) -> None:
x = dataset[0].copy()

View file

@ -30,7 +30,7 @@ class TestCV4AKenyaCropType:
monkeypatch.setattr(CV4AKenyaCropType, 'dates', ['20190606'])
monkeypatch.setattr(CV4AKenyaCropType, 'tile_height', 2)
monkeypatch.setattr(CV4AKenyaCropType, 'tile_width', 2)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return CV4AKenyaCropType(root, transforms=transforms, download=True)
@ -55,7 +55,7 @@ class TestCV4AKenyaCropType:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
CV4AKenyaCropType(str(tmp_path))
CV4AKenyaCropType(tmp_path)
def test_invalid_bands(self) -> None:
with pytest.raises(AssertionError):

View file

@ -28,7 +28,7 @@ class TestTropicalCyclone:
url = os.path.join('tests', 'data', 'cyclone')
monkeypatch.setattr(TropicalCyclone, 'url', url)
monkeypatch.setattr(TropicalCyclone, 'size', 2)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return TropicalCyclone(root, split, transforms, download=True)
@ -60,7 +60,7 @@ class TestTropicalCyclone:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
TropicalCyclone(str(tmp_path))
TropicalCyclone(tmp_path)
def test_plot(self, dataset: TropicalCyclone) -> None:
sample = dataset[0]

View file

@ -39,16 +39,14 @@ class TestDeepGlobeLandCover:
def test_extract(self, tmp_path: Path) -> None:
root = os.path.join('tests', 'data', 'deepglobelandcover')
filename = 'data.zip'
shutil.copyfile(
os.path.join(root, filename), os.path.join(str(tmp_path), filename)
)
DeepGlobeLandCover(root=str(tmp_path))
shutil.copyfile(os.path.join(root, filename), os.path.join(tmp_path, filename))
DeepGlobeLandCover(root=tmp_path)
def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, 'data.zip'), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
DeepGlobeLandCover(root=str(tmp_path), checksum=True)
DeepGlobeLandCover(root=tmp_path, checksum=True)
def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
@ -56,7 +54,7 @@ class TestDeepGlobeLandCover:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
DeepGlobeLandCover(str(tmp_path))
DeepGlobeLandCover(tmp_path)
def test_plot(self, dataset: DeepGlobeLandCover) -> None:
x = dataset[0].copy()

View file

@ -61,13 +61,13 @@ class TestDFC2022:
os.path.join('tests', 'data', 'dfc2022', 'val.zip'),
os.path.join(tmp_path, 'val.zip'),
)
DFC2022(root=str(tmp_path))
DFC2022(root=tmp_path)
def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, 'labeled_train.zip'), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
DFC2022(root=str(tmp_path), checksum=True)
DFC2022(root=tmp_path, checksum=True)
def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
@ -75,7 +75,7 @@ class TestDFC2022:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
DFC2022(str(tmp_path))
DFC2022(tmp_path)
def test_plot(self, dataset: DFC2022) -> None:
x = dataset[0].copy()

View file

@ -38,7 +38,7 @@ class TestEDDMapS:
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
EDDMapS(str(tmp_path))
EDDMapS(tmp_path)
def test_invalid_query(self, dataset: EDDMapS) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)

View file

@ -24,7 +24,7 @@ from torchgeo.datasets import (
from torchgeo.samplers import RandomGeoSampler
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -51,7 +51,7 @@ class TestEnviroAtlas:
'_files',
['pittsburgh_pa-2010_1m-train_tiles-debuffered', 'spatial_index.geojson'],
)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return EnviroAtlas(
root,
@ -85,7 +85,7 @@ class TestEnviroAtlas:
EnviroAtlas(root=dataset.root, download=True)
def test_already_downloaded(self, tmp_path: Path) -> None:
root = str(tmp_path)
root = tmp_path
shutil.copy(
os.path.join('tests', 'data', 'enviroatlas', 'enviroatlas_lotp.zip'), root
)
@ -93,7 +93,7 @@ class TestEnviroAtlas:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
EnviroAtlas(str(tmp_path), checksum=True)
EnviroAtlas(tmp_path, checksum=True)
def test_out_of_bounds_query(self, dataset: EnviroAtlas) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)

View file

@ -22,7 +22,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -42,7 +42,7 @@ class TestEsri2020:
'io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip',
)
monkeypatch.setattr(Esri2020, 'url', url)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return Esri2020(root, transforms=transforms, download=True, checksum=True)
@ -66,11 +66,11 @@ class TestEsri2020:
'io-lulc-model-001-v01-composite-v03-supercell-v02-clip-v01.zip',
)
shutil.copy(url, tmp_path)
Esri2020(str(tmp_path))
Esri2020(tmp_path)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
Esri2020(str(tmp_path), checksum=True)
Esri2020(tmp_path, checksum=True)
def test_and(self, dataset: Esri2020) -> None:
ds = dataset & dataset

View file

@ -16,7 +16,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import ETCI2021, DatasetNotFoundError
def download_url(url: str, root: str, *args: str) -> None:
def download_url(url: str, root: str | Path, *args: str) -> None:
shutil.copy(url, root)
@ -48,7 +48,7 @@ class TestETCI2021:
},
}
monkeypatch.setattr(ETCI2021, 'metadata', metadata)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return ETCI2021(root, split, transforms, download=True, checksum=True)
@ -78,7 +78,7 @@ class TestETCI2021:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ETCI2021(str(tmp_path))
ETCI2021(tmp_path)
def test_plot(self, dataset: ETCI2021) -> None:
x = dataset[0].copy()

View file

@ -28,7 +28,7 @@ class TestEUDEM:
monkeypatch.setattr(EUDEM, 'md5s', md5s)
zipfile = os.path.join('tests', 'data', 'eudem', 'eu_dem_v11_E30N10.zip')
shutil.copy(zipfile, tmp_path)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return EUDEM(root, transforms=transforms)
@ -42,7 +42,7 @@ class TestEUDEM:
assert len(dataset) == 1
def test_extracted_already(self, dataset: EUDEM) -> None:
assert isinstance(dataset.paths, str)
assert isinstance(dataset.paths, Path)
zipfile = os.path.join(dataset.paths, 'eu_dem_v11_E30N10.zip')
shutil.unpack_archive(zipfile, dataset.paths, 'zip')
EUDEM(dataset.paths)
@ -51,13 +51,13 @@ class TestEUDEM:
shutil.rmtree(tmp_path)
os.makedirs(tmp_path)
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
EUDEM(str(tmp_path))
EUDEM(tmp_path)
def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, 'eu_dem_v11_E30N10.zip'), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
EUDEM(str(tmp_path), checksum=True)
EUDEM(tmp_path, checksum=True)
def test_and(self, dataset: EUDEM) -> None:
ds = dataset & dataset
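The ``isinstance``/``Path`` fix from the commit message lands here: once ``dataset.paths`` is stored as a ``pathlib.Path``, assertions written against ``str`` become wrong and must flip, as in the ``EUDEM`` hunk above. A minimal illustration (hypothetical values, not from the diff):

from pathlib import Path

paths = Path('/tmp/eudem')         # dataset.paths is now a Path, not a str
assert isinstance(paths, Path)     # the updated check passes
assert not isinstance(paths, str)  # the old str-based check would now fail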

View file

@ -23,7 +23,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -42,7 +42,7 @@ class TestEuroCrops:
base_url = os.path.join('tests', 'data', 'eurocrops') + os.sep
monkeypatch.setattr(EuroCrops, 'base_url', base_url)
monkeypatch.setattr(plt, 'show', lambda *args: None)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return EuroCrops(
root, classes=classes, transforms=transforms, download=True, checksum=True
@ -81,7 +81,7 @@ class TestEuroCrops:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
EuroCrops(str(tmp_path))
EuroCrops(tmp_path)
def test_invalid_query(self, dataset: EuroCrops) -> None:
query = BoundingBox(200, 200, 200, 200, 2, 2)

View file

@ -24,7 +24,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -61,7 +61,7 @@ class TestEuroSAT:
'test': '4af60a00fdfdf8500572ae5360694b71',
},
)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return base_class(
root=root, split=split, transforms=transforms, download=True, checksum=True
@ -90,18 +90,18 @@ class TestEuroSAT:
assert len(ds) == 4
def test_already_downloaded(self, dataset: EuroSAT, tmp_path: Path) -> None:
EuroSAT(root=str(tmp_path), download=True)
EuroSAT(root=tmp_path, download=True)
def test_already_downloaded_not_extracted(
self, dataset: EuroSAT, tmp_path: Path
) -> None:
shutil.rmtree(dataset.root)
download_url(dataset.url, root=str(tmp_path))
EuroSAT(root=str(tmp_path), download=False)
download_url(dataset.url, root=tmp_path)
EuroSAT(root=tmp_path, download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
EuroSAT(str(tmp_path))
EuroSAT(tmp_path)
def test_plot(self, dataset: EuroSAT) -> None:
x = dataset[0].copy()
@ -114,7 +114,7 @@ class TestEuroSAT:
plt.close()
def test_plot_rgb(self, dataset: EuroSAT, tmp_path: Path) -> None:
dataset = EuroSAT(root=str(tmp_path), bands=('B03',))
dataset = EuroSAT(root=tmp_path, bands=('B03',))
with pytest.raises(
RGBBandsMissingError, match='Dataset does not contain some of the RGB bands'
):

View file

@ -16,7 +16,9 @@ import torchgeo.datasets.utils
from torchgeo.datasets import FAIR1M, DatasetNotFoundError
def download_url(url: str, root: str, filename: str, *args: str, **kwargs: str) -> None:
def download_url(
url: str, root: str | Path, filename: str, *args: str, **kwargs: str
) -> None:
os.makedirs(root, exist_ok=True)
shutil.copy(url, os.path.join(root, filename))
@ -65,7 +67,7 @@ class TestFAIR1M:
}
monkeypatch.setattr(FAIR1M, 'urls', urls)
monkeypatch.setattr(FAIR1M, 'md5s', md5s)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return FAIR1M(root, split, transforms, download=True, checksum=True)
@ -89,7 +91,7 @@ class TestFAIR1M:
assert len(dataset) == 4
def test_already_downloaded(self, dataset: FAIR1M, tmp_path: Path) -> None:
FAIR1M(root=str(tmp_path), split=dataset.split, download=True)
FAIR1M(root=tmp_path, split=dataset.split, download=True)
def test_already_downloaded_not_extracted(
self, dataset: FAIR1M, tmp_path: Path
@ -98,11 +100,11 @@ class TestFAIR1M:
for filepath, url in zip(
dataset.paths[dataset.split], dataset.urls[dataset.split]
):
output = os.path.join(str(tmp_path), filepath)
output = os.path.join(tmp_path, filepath)
os.makedirs(os.path.dirname(output), exist_ok=True)
download_url(url, root=os.path.dirname(output), filename=output)
FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True)
FAIR1M(root=tmp_path, split=dataset.split, checksum=True)
def test_corrupted(self, tmp_path: Path, dataset: FAIR1M) -> None:
md5s = tuple(['randomhash'] * len(FAIR1M.md5s[dataset.split]))
@ -111,17 +113,17 @@ class TestFAIR1M:
for filepath, url in zip(
dataset.paths[dataset.split], dataset.urls[dataset.split]
):
output = os.path.join(str(tmp_path), filepath)
output = os.path.join(tmp_path, filepath)
os.makedirs(os.path.dirname(output), exist_ok=True)
download_url(url, root=os.path.dirname(output), filename=output)
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
FAIR1M(root=str(tmp_path), split=dataset.split, checksum=True)
FAIR1M(root=tmp_path, split=dataset.split, checksum=True)
def test_not_downloaded(self, tmp_path: Path, dataset: FAIR1M) -> None:
shutil.rmtree(str(tmp_path))
shutil.rmtree(tmp_path)
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
FAIR1M(root=str(tmp_path), split=dataset.split)
FAIR1M(root=tmp_path, split=dataset.split)
def test_plot(self, dataset: FAIR1M) -> None:
x = dataset[0].copy()

View file

@ -16,7 +16,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, FireRisk
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -30,7 +30,7 @@ class TestFireRisk:
md5 = 'db22106d61b10d855234b4a74db921ac'
monkeypatch.setattr(FireRisk, 'md5', md5)
monkeypatch.setattr(FireRisk, 'url', url)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return FireRisk(root, split, transforms, download=True, checksum=True)
@ -46,18 +46,18 @@ class TestFireRisk:
assert len(dataset) == 5
def test_already_downloaded(self, dataset: FireRisk, tmp_path: Path) -> None:
FireRisk(root=str(tmp_path), download=True)
FireRisk(root=tmp_path, download=True)
def test_already_downloaded_not_extracted(
self, dataset: FireRisk, tmp_path: Path
) -> None:
shutil.rmtree(os.path.dirname(dataset.root))
download_url(dataset.url, root=str(tmp_path))
FireRisk(root=str(tmp_path), download=False)
download_url(dataset.url, root=tmp_path)
FireRisk(root=tmp_path, download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
FireRisk(str(tmp_path))
FireRisk(tmp_path)
def test_plot(self, dataset: FireRisk) -> None:
x = dataset[0].copy()

View file

@ -15,7 +15,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, ForestDamage
def download_url(url: str, root: str, *args: str) -> None:
def download_url(url: str, root: str | Path, *args: str) -> None:
shutil.copy(url, root)
@ -31,7 +31,7 @@ class TestForestDamage:
monkeypatch.setattr(ForestDamage, 'url', url)
monkeypatch.setattr(ForestDamage, 'md5', md5)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return ForestDamage(
root=root, transforms=transforms, download=True, checksum=True
@ -57,17 +57,17 @@ class TestForestDamage:
'tests', 'data', 'forestdamage', 'Data_Set_Larch_Casebearer.zip'
)
shutil.copy(url, tmp_path)
ForestDamage(root=str(tmp_path))
ForestDamage(root=tmp_path)
def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, 'Data_Set_Larch_Casebearer.zip'), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
ForestDamage(root=str(tmp_path), checksum=True)
ForestDamage(root=tmp_path, checksum=True)
def test_not_found(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ForestDamage(str(tmp_path))
ForestDamage(tmp_path)
def test_plot(self, dataset: ForestDamage) -> None:
x = dataset[0].copy()

View file

@ -38,7 +38,7 @@ class TestGBIF:
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
GBIF(str(tmp_path))
GBIF(tmp_path)
def test_invalid_query(self, dataset: GBIF) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)

View file

@ -36,7 +36,7 @@ class CustomGeoDataset(GeoDataset):
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
crs: CRS = CRS.from_epsg(4087),
res: float = 1,
paths: str | Iterable[str] | None = None,
paths: str | Path | Iterable[str | Path] | None = None,
) -> None:
super().__init__()
self.index.insert(0, tuple(bounds))
@ -172,7 +172,7 @@ class TestGeoDataset:
dataset & ds2 # type: ignore[operator]
def test_files_property_for_non_existing_file_or_dir(self, tmp_path: Path) -> None:
paths = [str(tmp_path), str(tmp_path / 'non_existing_file.tif')]
paths = [tmp_path, tmp_path / 'non_existing_file.tif']
with pytest.warns(UserWarning, match='Path was ignored.'):
assert len(CustomGeoDataset(paths=paths).files) == 0
@ -311,7 +311,7 @@ class TestRasterDataset:
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
RasterDataset(str(tmp_path))
RasterDataset(tmp_path)
def test_no_all_bands(self) -> None:
root = os.path.join('tests', 'data', 'sentinel2')
@ -380,7 +380,7 @@ class TestVectorDataset:
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
VectorDataset(str(tmp_path))
VectorDataset(tmp_path)
class TestNonGeoDataset:
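Illustrative only: the ``geo.py`` hunks above widen ``paths`` to ``str | Path | Iterable[str | Path]``, and the "all paths are Paths" bullets mean such mixed input is normalized to ``pathlib.Path`` internally. A hedged sketch of one possible normalization helper (the name ``_as_path_list`` is hypothetical, not torchgeo's actual code):

from collections.abc import Iterable
from pathlib import Path


def _as_path_list(paths: str | Path | Iterable[str | Path]) -> list[Path]:
    # Treat a bare str/Path as a one-element collection, then coerce everything.
    if isinstance(paths, (str, Path)):
        paths = [paths]
    return [Path(p) for p in paths]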

View file

@ -16,7 +16,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import GID15, DatasetNotFoundError
def download_url(url: str, root: str, *args: str) -> None:
def download_url(url: str, root: str | Path, *args: str) -> None:
shutil.copy(url, root)
@ -30,7 +30,7 @@ class TestGID15:
monkeypatch.setattr(GID15, 'md5', md5)
url = os.path.join('tests', 'data', 'gid15', 'gid-15.zip')
monkeypatch.setattr(GID15, 'url', url)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return GID15(root, split, transforms, download=True, checksum=True)
@ -59,7 +59,7 @@ class TestGID15:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
GID15(str(tmp_path))
GID15(tmp_path)
def test_plot(self, dataset: GID15) -> None:
dataset.plot(dataset[0], suptitle='Test')

View file

@ -37,7 +37,7 @@ class TestGlobBiomass:
}
monkeypatch.setattr(GlobBiomass, 'md5s', md5s)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return GlobBiomass(root, transforms=transforms, checksum=True)
@ -55,13 +55,13 @@ class TestGlobBiomass:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
GlobBiomass(str(tmp_path), checksum=True)
GlobBiomass(tmp_path, checksum=True)
def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, 'N00E020_agb.zip'), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
GlobBiomass(str(tmp_path), checksum=True)
GlobBiomass(tmp_path, checksum=True)
def test_and(self, dataset: GlobBiomass) -> None:
ds = dataset & dataset

View file

@ -19,7 +19,7 @@ from torchgeo.datasets import DatasetNotFoundError, IDTReeS
pytest.importorskip('laspy', minversion='2')
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -44,7 +44,7 @@ class TestIDTReeS:
}
split, task = request.param
monkeypatch.setattr(IDTReeS, 'metadata', metadata)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return IDTReeS(root, split, task, transforms, download=True, checksum=True)
@ -77,11 +77,11 @@ class TestIDTReeS:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
IDTReeS(str(tmp_path))
IDTReeS(tmp_path)
def test_not_extracted(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'idtrees', '*.zip')
root = str(tmp_path)
root = tmp_path
for zipfile in glob.iglob(pathname):
shutil.copy(zipfile, root)
IDTReeS(root)

View file

@ -38,7 +38,7 @@ class TestINaturalist:
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
INaturalist(str(tmp_path))
INaturalist(tmp_path)
def test_invalid_query(self, dataset: INaturalist) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)

View file

@ -50,7 +50,7 @@ class TestInriaAerialImageLabeling:
def test_not_downloaded(self, tmp_path: str) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
InriaAerialImageLabeling(str(tmp_path))
InriaAerialImageLabeling(tmp_path)
def test_dataset_checksum(self, dataset: InriaAerialImageLabeling) -> None:
InriaAerialImageLabeling.md5 = 'randommd5hash123'

View file

@ -24,7 +24,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -36,7 +36,7 @@ class TestIOBench:
url = os.path.join('tests', 'data', 'iobench', '{}.tar.gz')
monkeypatch.setattr(IOBench, 'url', url)
monkeypatch.setitem(IOBench.md5s, 'preprocessed', md5)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return IOBench(root, transforms=transforms, download=True, checksum=True)
@ -68,14 +68,14 @@ class TestIOBench:
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'iobench', '*.tar.gz')
root = str(tmp_path)
root = tmp_path
for tarfile in glob.iglob(pathname):
shutil.copy(tarfile, root)
IOBench(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
IOBench(str(tmp_path))
IOBench(tmp_path)
def test_invalid_query(self, dataset: IOBench) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)

View file

@ -25,7 +25,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -41,7 +41,7 @@ class TestL7Irish:
url = os.path.join('tests', 'data', 'l7irish', '{}.tar.gz')
monkeypatch.setattr(L7Irish, 'url', url)
monkeypatch.setattr(L7Irish, 'md5s', md5s)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return L7Irish(root, transforms=transforms, download=True, checksum=True)
@ -75,14 +75,14 @@ class TestL7Irish:
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'l7irish', '*.tar.gz')
root = str(tmp_path)
root = tmp_path
for tarfile in glob.iglob(pathname):
shutil.copy(tarfile, root)
L7Irish(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
L7Irish(str(tmp_path))
L7Irish(tmp_path)
def test_plot_prediction(self, dataset: L7Irish) -> None:
x = dataset[dataset.bounds]

View file

@ -25,7 +25,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -41,7 +41,7 @@ class TestL8Biome:
url = os.path.join('tests', 'data', 'l8biome', '{}.tar.gz')
monkeypatch.setattr(L8Biome, 'url', url)
monkeypatch.setattr(L8Biome, 'md5s', md5s)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return L8Biome(root, transforms=transforms, download=True, checksum=True)
@ -75,14 +75,14 @@ class TestL8Biome:
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'l8biome', '*.tar.gz')
root = str(tmp_path)
root = tmp_path
for tarfile in glob.iglob(pathname):
shutil.copy(tarfile, root)
L8Biome(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
L8Biome(str(tmp_path))
L8Biome(tmp_path)
def test_plot_prediction(self, dataset: L8Biome) -> None:
x = dataset[dataset.bounds]

View file

@ -22,7 +22,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -34,7 +34,7 @@ class TestLandCoverAIGeo:
monkeypatch.setattr(LandCoverAIGeo, 'md5', md5)
url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')
monkeypatch.setattr(LandCoverAIGeo, 'url', url)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return LandCoverAIGeo(root, transforms=transforms, download=True, checksum=True)
@ -49,13 +49,13 @@ class TestLandCoverAIGeo:
def test_already_downloaded(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')
root = str(tmp_path)
root = tmp_path
shutil.copy(url, root)
LandCoverAIGeo(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
LandCoverAIGeo(str(tmp_path))
LandCoverAIGeo(tmp_path)
def test_out_of_bounds_query(self, dataset: LandCoverAIGeo) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)
@ -89,7 +89,7 @@ class TestLandCoverAI:
monkeypatch.setattr(LandCoverAI, 'url', url)
sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return LandCoverAI(root, split, transforms, download=True, checksum=True)
@ -115,13 +115,13 @@ class TestLandCoverAI:
sha256 = 'ecec8e871faf1bbd8ca525ca95ddc1c1f5213f40afb94599884bd85f990ebd6b'
monkeypatch.setattr(LandCoverAI, 'sha256', sha256)
url = os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip')
root = str(tmp_path)
root = tmp_path
shutil.copy(url, root)
LandCoverAI(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
LandCoverAI(str(tmp_path))
LandCoverAI(tmp_path)
def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):

View file

@ -71,7 +71,7 @@ class TestLandsat8:
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
Landsat8(str(tmp_path))
Landsat8(tmp_path)
def test_invalid_query(self, dataset: Landsat8) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)

View file

@ -16,7 +16,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import LEVIRCD, DatasetNotFoundError, LEVIRCDPlus
def download_url(url: str, root: str, *args: str) -> None:
def download_url(url: str, root: str | Path, *args: str) -> None:
shutil.copy(url, root)
@ -45,7 +45,7 @@ class TestLEVIRCD:
}
monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url)
monkeypatch.setattr(LEVIRCD, 'splits', splits)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return LEVIRCD(root, split, transforms, download=True, checksum=True)
@ -71,7 +71,7 @@ class TestLEVIRCD:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
LEVIRCD(str(tmp_path))
LEVIRCD(tmp_path)
def test_plot(self, dataset: LEVIRCD) -> None:
dataset.plot(dataset[0], suptitle='Test')
@ -93,7 +93,7 @@ class TestLEVIRCDPlus:
monkeypatch.setattr(LEVIRCDPlus, 'md5', md5)
url = os.path.join('tests', 'data', 'levircd', 'levircdplus', 'LEVIR-CD+.zip')
monkeypatch.setattr(LEVIRCDPlus, 'url', url)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return LEVIRCDPlus(root, split, transforms, download=True, checksum=True)
@ -119,7 +119,7 @@ class TestLEVIRCDPlus:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
LEVIRCDPlus(str(tmp_path))
LEVIRCDPlus(tmp_path)
def test_plot(self, dataset: LEVIRCDPlus) -> None:
dataset.plot(dataset[0], suptitle='Test')

View file

@ -16,7 +16,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, LoveDA
def download_url(url: str, root: str, *args: str) -> None:
def download_url(url: str, root: str | Path, *args: str) -> None:
shutil.copy(url, root)
@ -48,7 +48,7 @@ class TestLoveDA:
monkeypatch.setattr(LoveDA, 'info_dict', info_dict)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return LoveDA(
@ -84,7 +84,7 @@ class TestLoveDA:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
LoveDA(str(tmp_path))
LoveDA(tmp_path)
def test_plot(self, dataset: LoveDA) -> None:
dataset.plot(dataset[0], suptitle='Test')

View file

@ -18,7 +18,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, MapInWild
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -53,7 +53,7 @@ class TestMapInWild:
urls = os.path.join('tests', 'data', 'mapinwild')
monkeypatch.setattr(MapInWild, 'url', urls)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
@ -98,12 +98,12 @@ class TestMapInWild:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
MapInWild(root=str(tmp_path))
MapInWild(root=tmp_path)
def test_downloaded_not_extracted(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'mapinwild', '*', '*')
pathname_glob = glob.glob(pathname)
root = str(tmp_path)
root = tmp_path
for zipfile in pathname_glob:
shutil.copy(zipfile, root)
MapInWild(root, download=False)
@ -111,7 +111,7 @@ class TestMapInWild:
def test_corrupted(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'mapinwild', '**', '*.zip')
pathname_glob = glob.glob(pathname, recursive=True)
root = str(tmp_path)
root = tmp_path
for zipfile in pathname_glob:
shutil.copy(zipfile, root)
splitfile = os.path.join(
@ -121,10 +121,10 @@ class TestMapInWild:
with open(os.path.join(tmp_path, 'mask.zip'), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
MapInWild(root=str(tmp_path), download=True, checksum=True)
MapInWild(root=tmp_path, download=True, checksum=True)
def test_already_downloaded(self, dataset: MapInWild, tmp_path: Path) -> None:
MapInWild(root=str(tmp_path), modality=dataset.modality, download=True)
MapInWild(root=tmp_path, modality=dataset.modality, download=True)
def test_plot(self, dataset: MapInWild) -> None:
x = dataset[0].copy()

View file

@ -39,18 +39,18 @@ class TestMillionAID:
def test_not_found(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
MillionAID(str(tmp_path))
MillionAID(tmp_path)
def test_not_extracted(self, tmp_path: Path) -> None:
url = os.path.join('tests', 'data', 'millionaid', 'train.zip')
shutil.copy(url, tmp_path)
MillionAID(str(tmp_path))
MillionAID(tmp_path)
def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, 'train.zip'), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
MillionAID(str(tmp_path), checksum=True)
MillionAID(tmp_path, checksum=True)
def test_plot(self, dataset: MillionAID) -> None:
x = dataset[0].copy()

View file

@ -51,7 +51,7 @@ class TestNAIP:
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
NAIP(str(tmp_path))
NAIP(tmp_path)
def test_invalid_query(self, dataset: NAIP) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)

View file

@ -45,7 +45,7 @@ class TestNASAMarineDebris:
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
md5s = ['6f4f0d2313323950e45bf3fc0c09b5de', '540cf1cf4fd2c13b609d0355abe955d7']
monkeypatch.setattr(NASAMarineDebris, 'md5s', md5s)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return NASAMarineDebris(root, transforms, download=True, checksum=True)
@ -63,15 +63,15 @@ class TestNASAMarineDebris:
def test_already_downloaded(
self, dataset: NASAMarineDebris, tmp_path: Path
) -> None:
NASAMarineDebris(root=str(tmp_path), download=True)
NASAMarineDebris(root=tmp_path, download=True)
def test_already_downloaded_not_extracted(
self, dataset: NASAMarineDebris, tmp_path: Path
) -> None:
shutil.rmtree(dataset.root)
os.makedirs(str(tmp_path), exist_ok=True)
os.makedirs(tmp_path, exist_ok=True)
Collection().download(output_dir=str(tmp_path))
NASAMarineDebris(root=str(tmp_path), download=False)
NASAMarineDebris(root=tmp_path, download=False)
def test_corrupted_previously_downloaded(self, tmp_path: Path) -> None:
filenames = NASAMarineDebris.filenames
@ -79,7 +79,7 @@ class TestNASAMarineDebris:
with open(os.path.join(tmp_path, filename), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset checksum mismatch.'):
NASAMarineDebris(root=str(tmp_path), download=False, checksum=True)
NASAMarineDebris(root=tmp_path, download=False, checksum=True)
def test_corrupted_new_download(
self, tmp_path: Path, monkeypatch: MonkeyPatch
@ -87,11 +87,11 @@ class TestNASAMarineDebris:
with pytest.raises(RuntimeError, match='Dataset checksum mismatch.'):
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_corrupted)
NASAMarineDebris(root=str(tmp_path), download=True, checksum=True)
NASAMarineDebris(root=tmp_path, download=True, checksum=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
NASAMarineDebris(str(tmp_path))
NASAMarineDebris(tmp_path)
def test_plot(self, dataset: NASAMarineDebris) -> None:
x = dataset[0].copy()

View file

@ -22,7 +22,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -43,7 +43,7 @@ class TestNCCM:
}
monkeypatch.setattr(NCCM, 'urls', urls)
transforms = nn.Identity()
root = str(tmp_path)
root = tmp_path
return NCCM(root, transforms=transforms, download=True, checksum=True)
def test_getitem(self, dataset: NCCM) -> None:
@ -84,7 +84,7 @@ class TestNCCM:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
NCCM(str(tmp_path))
NCCM(tmp_path)
def test_invalid_query(self, dataset: NCCM) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)

View file

@ -22,7 +22,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -42,7 +42,7 @@ class TestNLCD:
)
monkeypatch.setattr(NLCD, 'url', url)
monkeypatch.setattr(plt, 'show', lambda *args: None)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return NLCD(
root,
@ -84,7 +84,7 @@ class TestNLCD:
pathname = os.path.join(
'tests', 'data', 'nlcd', 'nlcd_2019_land_cover_l48_20210604.zip'
)
root = str(tmp_path)
root = tmp_path
shutil.copy(pathname, root)
NLCD(root, years=[2019])
@ -93,7 +93,7 @@ class TestNLCD:
AssertionError,
match='NLCD data product only exists for the following years:',
):
NLCD(str(tmp_path), years=[1996])
NLCD(tmp_path, years=[1996])
def test_invalid_classes(self) -> None:
with pytest.raises(AssertionError):
@ -117,7 +117,7 @@ class TestNLCD:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
NLCD(str(tmp_path))
NLCD(tmp_path)
def test_invalid_query(self, dataset: NLCD) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)

View file

@ -26,7 +26,7 @@ from torchgeo.datasets import (
class TestOpenBuildings:
@pytest.fixture
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> OpenBuildings:
root = str(tmp_path)
root = tmp_path
shutil.copy(
os.path.join('tests', 'data', 'openbuildings', 'tiles.geojson'), root
)
@ -55,7 +55,7 @@ class TestOpenBuildings:
def test_not_download(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
OpenBuildings(str(tmp_path))
OpenBuildings(tmp_path)
def test_corrupted(self, dataset: OpenBuildings, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, '000_buildings.csv.gz'), 'w') as f:

View file

@ -18,7 +18,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import OSCD, DatasetNotFoundError, RGBBandsMissingError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -63,7 +63,7 @@ class TestOSCD:
monkeypatch.setattr(OSCD, 'urls', urls)
bands, split = request.param
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return OSCD(
root, split, bands, transforms=transforms, download=True, checksum=True
@ -101,14 +101,14 @@ class TestOSCD:
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'oscd', '*Onera*.zip')
root = str(tmp_path)
root = tmp_path
for zipfile in glob.iglob(pathname):
shutil.copy(zipfile, root)
OSCD(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
OSCD(str(tmp_path))
OSCD(tmp_path)
def test_plot(self, dataset: OSCD) -> None:
dataset.plot(dataset[0], suptitle='Test')

View file

@ -17,7 +17,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import PASTIS, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -38,7 +38,7 @@ class TestPASTIS:
monkeypatch.setattr(PASTIS, 'md5', md5)
url = os.path.join('tests', 'data', 'pastis', 'PASTIS-R.zip')
monkeypatch.setattr(PASTIS, 'url', url)
root = str(tmp_path)
root = tmp_path
folds = request.param['folds']
bands = request.param['bands']
mode = request.param['mode']
@ -75,19 +75,19 @@ class TestPASTIS:
def test_already_downloaded(self, tmp_path: Path) -> None:
url = os.path.join('tests', 'data', 'pastis', 'PASTIS-R.zip')
root = str(tmp_path)
root = tmp_path
shutil.copy(url, root)
PASTIS(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
PASTIS(str(tmp_path))
PASTIS(tmp_path)
def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, 'PASTIS-R.zip'), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
PASTIS(root=str(tmp_path), checksum=True)
PASTIS(root=tmp_path, checksum=True)
def test_invalid_fold(self) -> None:
with pytest.raises(AssertionError):

View file

@ -15,7 +15,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, PatternNet
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -27,7 +27,7 @@ class TestPatternNet:
monkeypatch.setattr(PatternNet, 'md5', md5)
url = os.path.join('tests', 'data', 'patternnet', 'PatternNet.zip')
monkeypatch.setattr(PatternNet, 'url', url)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return PatternNet(root, transforms, download=True, checksum=True)
@ -42,18 +42,18 @@ class TestPatternNet:
assert len(dataset) == 2
def test_already_downloaded(self, dataset: PatternNet, tmp_path: Path) -> None:
PatternNet(root=str(tmp_path), download=True)
PatternNet(root=tmp_path, download=True)
def test_already_downloaded_not_extracted(
self, dataset: PatternNet, tmp_path: Path
) -> None:
shutil.rmtree(dataset.root)
download_url(dataset.url, root=str(tmp_path))
PatternNet(root=str(tmp_path), download=False)
download_url(dataset.url, root=tmp_path)
PatternNet(root=tmp_path, download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
PatternNet(str(tmp_path))
PatternNet(tmp_path)
def test_plot(self, dataset: PatternNet) -> None:
dataset.plot(dataset[0], suptitle='Test')

@ -43,9 +43,9 @@ class TestPotsdam2D:
root = os.path.join('tests', 'data', 'potsdam')
for filename in ['4_Ortho_RGBIR.zip', '5_Labels_all.zip']:
shutil.copyfile(
os.path.join(root, filename), os.path.join(str(tmp_path), filename)
os.path.join(root, filename), os.path.join(tmp_path, filename)
)
Potsdam2D(root=str(tmp_path))
Potsdam2D(root=tmp_path)
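The str() wrapper also disappears inside os.path.join calls, which works because the os.path functions accept any os.PathLike and still return plain strings. A quick illustration, using a temporary directory as a stand-in for the tmp_path fixture:

import os
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    tmp_path = Path(tmp)  # stand-in for pytest's tmp_path fixture
    # Path objects mix freely with str segments; the result is a str.
    target = os.path.join(tmp_path, '4_Ortho_RGBIR.zip')
    assert isinstance(target, str)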
def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, '4_Ortho_RGBIR.zip'), 'w') as f:
@ -53,7 +53,7 @@ class TestPotsdam2D:
with open(os.path.join(tmp_path, '5_Labels_all.zip'), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
Potsdam2D(root=str(tmp_path), checksum=True)
Potsdam2D(root=tmp_path, checksum=True)
def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
@ -61,7 +61,7 @@ class TestPotsdam2D:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
Potsdam2D(str(tmp_path))
Potsdam2D(tmp_path)
def test_plot(self, dataset: Potsdam2D) -> None:
x = dataset[0].copy()

@ -50,7 +50,7 @@ class TestPRISMA:
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
PRISMA(str(tmp_path))
PRISMA(tmp_path)
def test_invalid_query(self, dataset: PRISMA) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)

@ -18,7 +18,7 @@ from torchgeo.datasets import DatasetNotFoundError, QuakeSet
pytest.importorskip('h5py', minversion='3.6')
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -32,7 +32,7 @@ class TestQuakeSet:
md5 = '127d0d6a1f82d517129535f50053a4c9'
monkeypatch.setattr(QuakeSet, 'md5', md5)
monkeypatch.setattr(QuakeSet, 'url', url)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return QuakeSet(
@ -50,11 +50,11 @@ class TestQuakeSet:
assert len(dataset) == 8
def test_already_downloaded(self, dataset: QuakeSet, tmp_path: Path) -> None:
QuakeSet(root=str(tmp_path), download=True)
QuakeSet(root=tmp_path, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
QuakeSet(str(tmp_path))
QuakeSet(tmp_path)
def test_plot(self, dataset: QuakeSet) -> None:
x = dataset[0].copy()

@ -15,7 +15,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, ReforesTree
def download_url(url: str, root: str, *args: str) -> None:
def download_url(url: str, root: str | Path, *args: str) -> None:
shutil.copy(url, root)
@ -31,7 +31,7 @@ class TestReforesTree:
monkeypatch.setattr(ReforesTree, 'url', url)
monkeypatch.setattr(ReforesTree, 'md5', md5)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return ReforesTree(
root=root, transforms=transforms, download=True, checksum=True
@ -57,17 +57,17 @@ class TestReforesTree:
def test_not_extracted(self, tmp_path: Path) -> None:
url = os.path.join('tests', 'data', 'reforestree', 'reforesTree.zip')
shutil.copy(url, tmp_path)
ReforesTree(root=str(tmp_path))
ReforesTree(root=tmp_path)
def test_corrupted(self, tmp_path: Path) -> None:
with open(os.path.join(tmp_path, 'reforesTree.zip'), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
ReforesTree(root=str(tmp_path), checksum=True)
ReforesTree(root=tmp_path, checksum=True)
def test_not_found(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ReforesTree(str(tmp_path))
ReforesTree(tmp_path)
def test_plot(self, dataset: ReforesTree) -> None:
x = dataset[0].copy()

@ -18,7 +18,7 @@ from torchgeo.datasets import RESISC45, DatasetNotFoundError
pytest.importorskip('rarfile', minversion='4')
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -52,7 +52,7 @@ class TestRESISC45:
'test': '7760b1960c9a3ff46fb985810815e14d',
},
)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return RESISC45(root, split, transforms, download=True, checksum=True)
@ -68,18 +68,18 @@ class TestRESISC45:
assert len(dataset) == 9
def test_already_downloaded(self, dataset: RESISC45, tmp_path: Path) -> None:
RESISC45(root=str(tmp_path), download=True)
RESISC45(root=tmp_path, download=True)
def test_already_downloaded_not_extracted(
self, dataset: RESISC45, tmp_path: Path
) -> None:
shutil.rmtree(dataset.root)
download_url(dataset.url, root=str(tmp_path))
RESISC45(root=str(tmp_path), download=False)
download_url(dataset.url, root=tmp_path)
RESISC45(root=tmp_path, download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
RESISC45(str(tmp_path))
RESISC45(tmp_path)
def test_plot(self, dataset: RESISC45) -> None:
x = dataset[0].copy()

@ -33,7 +33,7 @@ class TestRwandaFieldBoundary:
monkeypatch.setattr(RwandaFieldBoundary, 'url', url)
monkeypatch.setattr(RwandaFieldBoundary, 'splits', {'train': 1, 'test': 1})
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return RwandaFieldBoundary(root, split, transforms=transforms, download=True)
@ -60,7 +60,7 @@ class TestRwandaFieldBoundary:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
RwandaFieldBoundary(str(tmp_path))
RwandaFieldBoundary(tmp_path)
def test_invalid_bands(self) -> None:
with pytest.raises(AssertionError):

@ -18,7 +18,9 @@ import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, SeasoNet
def download_url(url: str, root: str, md5: str, *args: str, **kwargs: str) -> None:
def download_url(
url: str, root: str | Path, md5: str, *args: str, **kwargs: str
) -> None:
shutil.copy(url, root)
torchgeo.datasets.utils.check_integrity(
os.path.join(root, os.path.basename(url)), md5
@ -95,7 +97,7 @@ class TestSeasoNet:
'url',
os.path.join('tests', 'data', 'seasonet', 'meta.csv'),
)
root = str(tmp_path)
root = tmp_path
split, seasons, bands, grids, concat_seasons = request.param
transforms = nn.Identity()
return SeasoNet(
@ -141,14 +143,14 @@ class TestSeasoNet:
def test_already_downloaded(self, tmp_path: Path) -> None:
paths = os.path.join('tests', 'data', 'seasonet', '*.*')
root = str(tmp_path)
root = tmp_path
for path in glob.iglob(paths):
shutil.copy(path, root)
SeasoNet(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SeasoNet(str(tmp_path), download=False)
SeasoNet(tmp_path, download=False)
def test_out_of_bounds(self, dataset: SeasoNet) -> None:
with pytest.raises(IndexError):

@ -22,7 +22,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -56,7 +56,7 @@ class TestSeasonalContrastS2:
monkeypatch.setitem(
SeasonalContrastS2.metadata['1m'], 'md5', '3bb3fcf90f5de7d5781ce0cb85fd20af'
)
root = str(tmp_path)
root = tmp_path
version, seasons, bands = request.param
transforms = nn.Identity()
return SeasonalContrastS2(
@ -88,7 +88,7 @@ class TestSeasonalContrastS2:
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'seco', '*.zip')
root = str(tmp_path)
root = tmp_path
for zipfile in glob.iglob(pathname):
shutil.copy(zipfile, root)
SeasonalContrastS2(root)
@ -103,7 +103,7 @@ class TestSeasonalContrastS2:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SeasonalContrastS2(str(tmp_path))
SeasonalContrastS2(tmp_path)
def test_plot(self, dataset: SeasonalContrastS2) -> None:
x = dataset[0]

@ -66,10 +66,10 @@ class TestSEN12MS:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SEN12MS(str(tmp_path), checksum=True)
SEN12MS(tmp_path, checksum=True)
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SEN12MS(str(tmp_path), checksum=False)
SEN12MS(tmp_path, checksum=False)
def test_check_integrity_light(self) -> None:
root = os.path.join('tests', 'data', 'sen12ms')

@ -70,7 +70,7 @@ class TestSentinel1:
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
Sentinel1(str(tmp_path))
Sentinel1(tmp_path)
def test_empty_bands(self) -> None:
with pytest.raises(AssertionError, match="'bands' cannot be an empty list"):
@ -132,7 +132,7 @@ class TestSentinel2:
def test_no_data(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
Sentinel2(str(tmp_path))
Sentinel2(tmp_path)
def test_plot(self, dataset: Sentinel2) -> None:
x = dataset[dataset.bounds]

@ -19,7 +19,7 @@ from torchgeo.datasets import SKIPPD, DatasetNotFoundError
pytest.importorskip('h5py', minversion='3.6')
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -40,7 +40,7 @@ class TestSKIPPD:
url = os.path.join('tests', 'data', 'skippd', '{}')
monkeypatch.setattr(SKIPPD, 'url', url)
monkeypatch.setattr(plt, 'show', lambda *args: None)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return SKIPPD(
root=root,
@ -59,7 +59,7 @@ class TestSKIPPD:
pathname = os.path.join(
'tests', 'data', 'skippd', f'2017_2019_images_pv_processed_{task}.zip'
)
root = str(tmp_path)
root = tmp_path
shutil.copy(pathname, root)
SKIPPD(root=root, task=task)
@ -84,7 +84,7 @@ class TestSKIPPD:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SKIPPD(str(tmp_path))
SKIPPD(tmp_path)
def test_plot(self, dataset: SKIPPD) -> None:
dataset.plot(dataset[0], suptitle='Test')

@ -58,7 +58,7 @@ class TestSo2Sat:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
So2Sat(str(tmp_path))
So2Sat(tmp_path)
def test_plot(self, dataset: So2Sat) -> None:
x = dataset[0].copy()

@ -52,7 +52,7 @@ class TestSouthAfricaCropType:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SouthAfricaCropType(str(tmp_path))
SouthAfricaCropType(tmp_path)
def test_plot(self) -> None:
path = os.path.join('tests', 'data', 'south_africa_crop_type')

@ -21,7 +21,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -37,7 +37,7 @@ class TestSouthAmericaSoybean:
)
monkeypatch.setattr(SouthAmericaSoybean, 'url', url)
root = str(tmp_path)
root = tmp_path
return SouthAmericaSoybean(
paths=root,
years=[2002, 2021],
@ -70,7 +70,7 @@ class TestSouthAmericaSoybean:
pathname = os.path.join(
'tests', 'data', 'south_america_soybean', 'SouthAmerica_Soybean_2002.tif'
)
root = str(tmp_path)
root = tmp_path
shutil.copy(pathname, root)
SouthAmericaSoybean(root)
@ -89,7 +89,7 @@ class TestSouthAmericaSoybean:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SouthAmericaSoybean(str(tmp_path))
SouthAmericaSoybean(tmp_path)
def test_invalid_query(self, dataset: SouthAmericaSoybean) -> None:
query = BoundingBox(0, 0, 0, 0, 0, 0)

@ -68,7 +68,7 @@ class TestSpaceNet1:
# Refer https://github.com/python/mypy/issues/1032
monkeypatch.setattr(SpaceNet1, 'collection_md5_dict', test_md5)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return SpaceNet1(
root, image=request.param, transforms=transforms, download=True, api_key=''
@ -93,7 +93,7 @@ class TestSpaceNet1:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SpaceNet1(str(tmp_path))
SpaceNet1(tmp_path)
def test_plot(self, dataset: SpaceNet1) -> None:
x = dataset[0].copy()
@ -118,7 +118,7 @@ class TestSpaceNet2:
}
monkeypatch.setattr(SpaceNet2, 'collection_md5_dict', test_md5)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return SpaceNet2(
root,
@ -149,7 +149,7 @@ class TestSpaceNet2:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SpaceNet2(str(tmp_path))
SpaceNet2(tmp_path)
def test_collection_checksum(self, dataset: SpaceNet2) -> None:
dataset.collection_md5_dict['sn2_AOI_2_Vegas'] = 'randommd5hash123'
@ -177,7 +177,7 @@ class TestSpaceNet3:
}
monkeypatch.setattr(SpaceNet3, 'collection_md5_dict', test_md5)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return SpaceNet3(
root,
@ -209,7 +209,7 @@ class TestSpaceNet3:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SpaceNet3(str(tmp_path))
SpaceNet3(tmp_path)
def test_collection_checksum(self, dataset: SpaceNet3) -> None:
dataset.collection_md5_dict['sn3_AOI_5_Khartoum'] = 'randommd5hash123'
@ -240,7 +240,7 @@ class TestSpaceNet4:
test_angles = ['nadir', 'off-nadir', 'very-off-nadir']
monkeypatch.setattr(SpaceNet4, 'collection_md5_dict', test_md5)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return SpaceNet4(
root,
@ -273,7 +273,7 @@ class TestSpaceNet4:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SpaceNet4(str(tmp_path))
SpaceNet4(tmp_path)
def test_collection_checksum(self, dataset: SpaceNet4) -> None:
dataset.collection_md5_dict['sn4_AOI_6_Atlanta'] = 'randommd5hash123'
@ -303,7 +303,7 @@ class TestSpaceNet5:
}
monkeypatch.setattr(SpaceNet5, 'collection_md5_dict', test_md5)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return SpaceNet5(
root,
@ -335,7 +335,7 @@ class TestSpaceNet5:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SpaceNet5(str(tmp_path))
SpaceNet5(tmp_path)
def test_collection_checksum(self, dataset: SpaceNet5) -> None:
dataset.collection_md5_dict['sn5_AOI_8_Mumbai'] = 'randommd5hash123'
@ -359,7 +359,7 @@ class TestSpaceNet6:
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
) -> SpaceNet6:
monkeypatch.setattr(radiant_mlhub.Dataset, 'fetch', fetch_dataset)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return SpaceNet6(
root, image=request.param, transforms=transforms, download=True, api_key=''
@ -405,7 +405,7 @@ class TestSpaceNet7:
}
monkeypatch.setattr(SpaceNet7, 'collection_md5_dict', test_md5)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return SpaceNet7(
root, split=request.param, transforms=transforms, download=True, api_key=''
@ -429,7 +429,7 @@ class TestSpaceNet7:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SpaceNet7(str(tmp_path))
SpaceNet7(tmp_path)
def test_collection_checksum(self, dataset: SpaceNet4) -> None:
dataset.collection_md5_dict['sn7_train_source'] = 'randommd5hash123'

@ -18,7 +18,7 @@ import torchgeo
from torchgeo.datasets import SSL4EOL, SSL4EOS12, DatasetNotFoundError
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -61,7 +61,7 @@ class TestSSL4EOL:
}
monkeypatch.setattr(SSL4EOL, 'checksums', checksums)
root = str(tmp_path)
root = tmp_path
split, seasons = request.param
transforms = nn.Identity()
return SSL4EOL(root, split, seasons, transforms, download=True, checksum=True)
@ -88,14 +88,14 @@ class TestSSL4EOL:
def test_already_downloaded(self, dataset: SSL4EOL, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'ssl4eo', 'l', '*.tar.gz*')
root = str(tmp_path)
root = tmp_path
for tarfile in glob.iglob(pathname):
shutil.copy(tarfile, root)
SSL4EOL(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SSL4EOL(str(tmp_path))
SSL4EOL(tmp_path)
def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
@ -148,7 +148,7 @@ class TestSSL4EOS12:
os.path.join('tests', 'data', 'ssl4eo', 's12', filename),
tmp_path / filename,
)
SSL4EOS12(str(tmp_path))
SSL4EOS12(tmp_path)
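This hunk builds the copy target with Path's / operator instead of os.path.join; unlike os.path.join, it yields another Path. A small check with placeholder names:

from pathlib import Path

tmp_path = Path('/tmp/pytest-0/test0')  # stand-in for the fixture

# The overloaded '/' joins segments and keeps the pathlib type.
target = tmp_path / 'archive.tar.gz'
assert isinstance(target, Path)
assert target.name == 'archive.tar.gz'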
def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
@ -156,7 +156,7 @@ class TestSSL4EOS12:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SSL4EOS12(str(tmp_path))
SSL4EOS12(tmp_path)
def test_plot(self, dataset: SSL4EOS12) -> None:
sample = dataset[0]

@ -25,7 +25,7 @@ from torchgeo.datasets import (
)
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -43,7 +43,7 @@ class TestSSL4EOLBenchmark:
monkeypatch.setattr(
torchgeo.datasets.ssl4eo_benchmark, 'download_url', download_url
)
root = str(tmp_path)
root = tmp_path
url = os.path.join('tests', 'data', 'ssl4eo_benchmark_landsat', '{}.tar.gz')
monkeypatch.setattr(SSL4EOLBenchmark, 'url', url)
@ -140,14 +140,14 @@ class TestSSL4EOLBenchmark:
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'ssl4eo_benchmark_landsat', '*.tar.gz')
root = str(tmp_path)
root = tmp_path
for tarfile in glob.iglob(pathname):
shutil.copy(tarfile, root)
SSL4EOLBenchmark(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SSL4EOLBenchmark(str(tmp_path))
SSL4EOLBenchmark(tmp_path)
def test_plot(self, dataset: SSL4EOLBenchmark) -> None:
sample = dataset[0]

@ -16,7 +16,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, SustainBenchCropYield
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -34,7 +34,7 @@ class TestSustainBenchCropYield:
url = os.path.join('tests', 'data', 'sustainbench_crop_yield', 'soybeans.zip')
monkeypatch.setattr(SustainBenchCropYield, 'url', url)
monkeypatch.setattr(plt, 'show', lambda *args: None)
root = str(tmp_path)
root = tmp_path
split = request.param
countries = ['argentina', 'brazil', 'usa']
transforms = nn.Identity()
@ -49,7 +49,7 @@ class TestSustainBenchCropYield:
pathname = os.path.join(
'tests', 'data', 'sustainbench_crop_yield', 'soybeans.zip'
)
root = str(tmp_path)
root = tmp_path
shutil.copy(pathname, root)
SustainBenchCropYield(root)
@ -72,7 +72,7 @@ class TestSustainBenchCropYield:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SustainBenchCropYield(str(tmp_path))
SustainBenchCropYield(tmp_path)
def test_plot(self, dataset: SustainBenchCropYield) -> None:
dataset.plot(dataset[0], suptitle='Test')

@ -17,7 +17,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, UCMerced
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -51,7 +51,7 @@ class TestUCMerced:
'test': 'a01fa9f13333bb176fc1bfe26ff4c711',
},
)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return UCMerced(root, split, transforms, download=True, checksum=True)
@ -71,18 +71,18 @@ class TestUCMerced:
assert len(ds) == 8
def test_already_downloaded(self, dataset: UCMerced, tmp_path: Path) -> None:
UCMerced(root=str(tmp_path), download=True)
UCMerced(root=tmp_path, download=True)
def test_already_downloaded_not_extracted(
self, dataset: UCMerced, tmp_path: Path
) -> None:
shutil.rmtree(dataset.root)
download_url(dataset.url, root=str(tmp_path))
UCMerced(root=str(tmp_path), download=False)
download_url(dataset.url, root=tmp_path)
UCMerced(root=tmp_path, download=False)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
UCMerced(str(tmp_path))
UCMerced(tmp_path)
def test_plot(self, dataset: UCMerced) -> None:
x = dataset[0].copy()

@ -17,7 +17,7 @@ import torchgeo.datasets.utils
from torchgeo.datasets import DatasetNotFoundError, USAVars
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -73,7 +73,7 @@ class TestUSAVars:
}
monkeypatch.setattr(USAVars, 'split_metadata', split_metadata)
root = str(tmp_path)
root = tmp_path
split, labels = request.param
transforms = nn.Identity()
@ -109,7 +109,7 @@ class TestUSAVars:
def test_already_downloaded(self, tmp_path: Path) -> None:
pathname = os.path.join('tests', 'data', 'usavars', 'uar.zip')
root = str(tmp_path)
root = tmp_path
shutil.copy(pathname, root)
csvs = [
'elevation.csv',
@ -130,7 +130,7 @@ class TestUSAVars:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
USAVars(str(tmp_path))
USAVars(tmp_path)
def test_plot(self, dataset: USAVars) -> None:
dataset.plot(dataset[0], suptitle='Test')

@ -65,7 +65,7 @@ def fetch_collection(collection_id: str, **kwargs: str) -> Collection:
return Collection()
def download_url(url: str, root: str, *args: str) -> None:
def download_url(url: str, root: str | Path, *args: str) -> None:
shutil.copy(url, root)
@ -85,7 +85,7 @@ def test_extract_archive(src: str, tmp_path: Path) -> None:
pytest.importorskip('rarfile', minversion='4')
if src.startswith('chesapeake'):
pytest.importorskip('zipfile_deflate64')
extract_archive(os.path.join('tests', 'data', src), str(tmp_path))
extract_archive(os.path.join('tests', 'data', src), tmp_path)
def test_unsupported_scheme() -> None:
@ -98,8 +98,7 @@ def test_unsupported_scheme() -> None:
def test_download_and_extract_archive(tmp_path: Path, monkeypatch: MonkeyPatch) -> None:
monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url)
download_and_extract_archive(
os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip'),
str(tmp_path),
os.path.join('tests', 'data', 'landcoverai', 'landcover.ai.v1.zip'), tmp_path
)
@ -108,7 +107,7 @@ def test_download_radiant_mlhub_dataset(
) -> None:
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
monkeypatch.setattr(radiant_mlhub.Dataset, 'fetch', fetch_dataset)
download_radiant_mlhub_dataset('', str(tmp_path))
download_radiant_mlhub_dataset('', tmp_path)
def test_download_radiant_mlhub_collection(
@ -116,7 +115,7 @@ def test_download_radiant_mlhub_collection(
) -> None:
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection)
download_radiant_mlhub_collection('', str(tmp_path))
download_radiant_mlhub_collection('', tmp_path)
class TestBoundingBox:

@ -49,9 +49,9 @@ class TestVaihingen2D:
]
for filename in filenames:
shutil.copyfile(
os.path.join(root, filename), os.path.join(str(tmp_path), filename)
os.path.join(root, filename), os.path.join(tmp_path, filename)
)
Vaihingen2D(root=str(tmp_path))
Vaihingen2D(root=tmp_path)
def test_corrupted(self, tmp_path: Path) -> None:
filenames = [
@ -62,7 +62,7 @@ class TestVaihingen2D:
with open(os.path.join(tmp_path, filename), 'w') as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
Vaihingen2D(root=str(tmp_path), checksum=True)
Vaihingen2D(root=tmp_path, checksum=True)
def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
@ -70,7 +70,7 @@ class TestVaihingen2D:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
Vaihingen2D(str(tmp_path))
Vaihingen2D(tmp_path)
def test_plot(self, dataset: Vaihingen2D) -> None:
x = dataset[0].copy()

@ -20,7 +20,7 @@ pytest.importorskip('pycocotools')
pytest.importorskip('rarfile', minversion='4')
def download_url(url: str, root: str, *args: str) -> None:
def download_url(url: str, root: str | Path, *args: str) -> None:
shutil.copy(url, root)
@ -39,7 +39,7 @@ class TestVHR10:
monkeypatch.setitem(VHR10.target_meta, 'url', url)
md5 = '567c4cd8c12624864ff04865de504c58'
monkeypatch.setitem(VHR10.target_meta, 'md5', md5)
root = str(tmp_path)
root = tmp_path
split = request.param
transforms = nn.Identity()
return VHR10(root, split, transforms, download=True, checksum=True)
@ -78,7 +78,7 @@ class TestVHR10:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
VHR10(str(tmp_path))
VHR10(tmp_path)
def test_plot(self, dataset: VHR10) -> None:
pytest.importorskip('skimage', minversion='0.19')

@ -37,7 +37,7 @@ class TestWesternUSALiveFuelMoisture:
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
md5 = 'ecbc9269dd27c4efe7aa887960054351'
monkeypatch.setattr(WesternUSALiveFuelMoisture, 'md5', md5)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return WesternUSALiveFuelMoisture(
root, transforms=transforms, download=True, api_key='', checksum=True
@ -60,13 +60,13 @@ class TestWesternUSALiveFuelMoisture:
'western_usa_live_fuel_moisture',
'su_sar_moisture_content.tar.gz',
)
root = str(tmp_path)
root = tmp_path
shutil.copy(pathname, root)
WesternUSALiveFuelMoisture(root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
WesternUSALiveFuelMoisture(str(tmp_path))
WesternUSALiveFuelMoisture(tmp_path)
def test_invalid_features(self, dataset: WesternUSALiveFuelMoisture) -> None:
with pytest.raises(AssertionError, match='Invalid input variable name.'):

@ -61,7 +61,7 @@ class TestXView2:
),
os.path.join(tmp_path, 'test_images_labels_targets.tar.gz'),
)
XView2(root=str(tmp_path))
XView2(root=tmp_path)
def test_corrupted(self, tmp_path: Path) -> None:
with open(
@ -73,7 +73,7 @@ class TestXView2:
) as f:
f.write('bad')
with pytest.raises(RuntimeError, match='Dataset found, but corrupted.'):
XView2(root=str(tmp_path), checksum=True)
XView2(root=tmp_path, checksum=True)
def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
@ -81,7 +81,7 @@ class TestXView2:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
XView2(str(tmp_path))
XView2(tmp_path)
def test_plot(self, dataset: XView2) -> None:
x = dataset[0].copy()

@ -17,7 +17,7 @@ from torchgeo.datasets import DatasetNotFoundError, RGBBandsMissingError, ZueriC
pytest.importorskip('h5py', minversion='3.6')
def download_url(url: str, root: str, *args: str, **kwargs: str) -> None:
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
shutil.copy(url, root)
@ -33,7 +33,7 @@ class TestZueriCrop:
md5s = ['1635231df67f3d25f4f1e62c98e221a4', '5118398c7a5bbc246f5f6bb35d8d529b']
monkeypatch.setattr(ZueriCrop, 'urls', urls)
monkeypatch.setattr(ZueriCrop, 'md5s', md5s)
root = str(tmp_path)
root = tmp_path
transforms = nn.Identity()
return ZueriCrop(root=root, transforms=transforms, download=True, checksum=True)
@ -67,7 +67,7 @@ class TestZueriCrop:
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
ZueriCrop(str(tmp_path))
ZueriCrop(tmp_path)
def test_invalid_bands(self) -> None:
with pytest.raises(ValueError):

@ -17,7 +17,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_and_extract_archive, lazy_import
from .utils import Path, download_and_extract_archive, lazy_import
class ADVANCE(NonGeoDataset):
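From here on the dataset modules stop annotating root as str and instead import a Path type from torchgeo.datasets.utils. A sketch of what such an alias amounts to, assuming it is a plain union (the real definition in torchgeo may differ, for example by building on os.PathLike[str]):

import pathlib
from typing import TypeAlias

Path: TypeAlias = str | pathlib.Path  # assumed shape of the alias


class Demo:
    def __init__(self, root: Path = 'data') -> None:
        # Either spelling is accepted; normalize once internally.
        self.root = pathlib.Path(root)


Demo('data')
Demo(pathlib.Path('data'))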
@ -88,7 +88,7 @@ class ADVANCE(NonGeoDataset):
def __init__(
self,
root: str = 'data',
root: Path = 'data',
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
@ -151,7 +151,7 @@ class ADVANCE(NonGeoDataset):
"""
return len(self.files)
def _load_files(self, root: str) -> list[dict[str, str]]:
def _load_files(self, root: Path) -> list[dict[str, str]]:
"""Return the paths of the files in the dataset.
Args:
@ -169,7 +169,7 @@ class ADVANCE(NonGeoDataset):
]
return files
def _load_image(self, path: str) -> Tensor:
def _load_image(self, path: Path) -> Tensor:
"""Load a single image.
Args:
@ -185,7 +185,7 @@ class ADVANCE(NonGeoDataset):
tensor = tensor.permute((2, 0, 1))
return tensor
def _load_target(self, path: str) -> Tensor:
def _load_target(self, path: Path) -> Tensor:
"""Load the target audio for a single image.
Args:

@ -5,6 +5,7 @@
import json
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any
@ -14,7 +15,7 @@ from rasterio.crs import CRS
from .errors import DatasetNotFoundError
from .geo import RasterDataset
from .utils import download_url
from .utils import Path, download_url
class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
@ -57,7 +58,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
def __init__(
self,
paths: str | Iterable[str] = 'data',
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
@ -105,7 +106,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
def _download(self) -> None:
"""Download the dataset."""
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
download_url(self.url, self.paths, self.base_filename)
with open(os.path.join(self.paths, self.base_filename)) as f:
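These widened asserts lean on isinstance accepting union types directly (Python 3.10+): a single root passes whether it is a str or a pathlib.Path, while an iterable of roots is still rejected, since _download needs one target directory. A quick check:

import pathlib

for paths in ('data', pathlib.Path('data')):
    # Both spellings of a single root satisfy the guard.
    assert isinstance(paths, str | pathlib.Path)

# A collection of roots still fails it, as intended.
assert not isinstance(['a/data', 'b/data'], str | pathlib.Path)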

@ -4,6 +4,7 @@
"""AgriFieldNet India Challenge dataset."""
import os
import pathlib
import re
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
@ -16,7 +17,7 @@ from torch import Tensor
from .errors import RGBBandsMissingError
from .geo import RasterDataset
from .utils import BoundingBox
from .utils import BoundingBox, Path
class AgriFieldNet(RasterDataset):
@ -115,7 +116,7 @@ class AgriFieldNet(RasterDataset):
def __init__(
self,
paths: str | Iterable[str] = 'data',
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
classes: list[int] = list(cmap.keys()),
bands: Sequence[str] = all_bands,
@ -167,10 +168,10 @@ class AgriFieldNet(RasterDataset):
Returns:
data, label, and field ids at that index
"""
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[str], [hit.object for hit in hits])
filepaths = cast(list[Path], [hit.object for hit in hits])
if not filepaths:
raise IndexError(

@ -12,6 +12,7 @@ from rasterio.crs import CRS
from .errors import DatasetNotFoundError
from .geo import RasterDataset
from .utils import Path
class AsterGDEM(RasterDataset):
@ -47,7 +48,7 @@ class AsterGDEM(RasterDataset):
def __init__(
self,
paths: str | list[str] = 'data',
paths: Path | list[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,

@ -19,7 +19,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
from .utils import which
from .utils import Path, which
class BeninSmallHolderCashews(NonGeoDataset):
@ -163,7 +163,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
def __init__(
self,
root: str = 'data',
root: Path = 'data',
chip_size: int = 256,
stride: int = 128,
bands: Sequence[str] = all_bands,

@ -18,7 +18,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, extract_archive, sort_sentinel2_bands
from .utils import Path, download_url, extract_archive, sort_sentinel2_bands
class BigEarthNet(NonGeoDataset):
@ -267,7 +267,7 @@ class BigEarthNet(NonGeoDataset):
def __init__(
self,
root: str = 'data',
root: Path = 'data',
split: str = 'train',
bands: str = 'all',
num_classes: int = 19,
@ -486,7 +486,7 @@ class BigEarthNet(NonGeoDataset):
filepath = os.path.join(self.root, filename)
self._extract(filepath)
def _download(self, url: str, filename: str, md5: str) -> None:
def _download(self, url: str, filename: Path, md5: str) -> None:
"""Download the dataset.
Args:
@ -499,13 +499,13 @@ class BigEarthNet(NonGeoDataset):
url, self.root, filename=filename, md5=md5 if self.checksum else None
)
def _extract(self, filepath: str) -> None:
def _extract(self, filepath: Path) -> None:
"""Extract the dataset.
Args:
filepath: path to file to be extracted
"""
if not filepath.endswith('.csv'):
if not str(filepath).endswith('.csv'):
extract_archive(filepath)
def _onehot_labels_to_names(
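The str() conversion here is forced by an API difference: pathlib.Path objects have no endswith method, only strings do. A pathlib-native spelling would compare the suffix instead; both checks below agree (file name is a placeholder):

import pathlib

filepath = pathlib.Path('BigEarthNet-S2') / 'metadata.csv'

# What _extract does after the fix: compare on the string form.
assert str(filepath).endswith('.csv')

# An equivalent pathlib-native alternative.
assert filepath.suffix == '.csv'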

@ -16,7 +16,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import percentile_normalization
from .utils import Path, percentile_normalization
class BioMassters(NonGeoDataset):
@ -57,7 +57,7 @@ class BioMassters(NonGeoDataset):
def __init__(
self,
root: str = 'data',
root: Path = 'data',
split: str = 'train',
sensors: Sequence[str] = ['S1', 'S2'],
as_time_series: bool = False,
@ -167,7 +167,7 @@ class BioMassters(NonGeoDataset):
"""
return len(self.df['num_index'].unique())
def _load_input(self, filenames: list[str]) -> Tensor:
def _load_input(self, filenames: list[Path]) -> Tensor:
"""Load the input imagery at the index.
Args:
@ -186,7 +186,7 @@ class BioMassters(NonGeoDataset):
arr = np.concatenate(arr_list, axis=0)
return torch.tensor(arr.astype(np.int32))
def _load_target(self, filename: str) -> Tensor:
def _load_target(self, filename: Path) -> Tensor:
"""Load the target mask at the index.
Args:

@ -4,6 +4,7 @@
"""Canadian Building Footprints dataset."""
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any
@ -13,7 +14,7 @@ from rasterio.crs import CRS
from .errors import DatasetNotFoundError
from .geo import VectorDataset
from .utils import check_integrity, download_and_extract_archive
from .utils import Path, check_integrity, download_and_extract_archive
class CanadianBuildingFootprints(VectorDataset):
@ -62,7 +63,7 @@ class CanadianBuildingFootprints(VectorDataset):
def __init__(
self,
paths: str | Iterable[str] = 'data',
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
res: float = 0.00001,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
@ -104,7 +105,7 @@ class CanadianBuildingFootprints(VectorDataset):
Returns:
True if dataset files are found and/or MD5s match, else False
"""
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
filepath = os.path.join(self.paths, prov_terr + '.zip')
if not check_integrity(filepath, md5 if self.checksum else None):
@ -116,7 +117,7 @@ class CanadianBuildingFootprints(VectorDataset):
if self._check_integrity():
print('Files already downloaded and verified')
return
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
for prov_terr, md5 in zip(self.provinces_territories, self.md5s):
download_and_extract_archive(
self.url + prov_terr + '.zip',

@ -4,6 +4,7 @@
"""CDL dataset."""
import os
import pathlib
from collections.abc import Callable, Iterable
from typing import Any
@ -14,7 +15,7 @@ from rasterio.crs import CRS
from .errors import DatasetNotFoundError
from .geo import RasterDataset
from .utils import BoundingBox, download_url, extract_archive
from .utils import BoundingBox, Path, download_url, extract_archive
class CDL(RasterDataset):
@ -207,7 +208,7 @@ class CDL(RasterDataset):
def __init__(
self,
paths: str | Iterable[str] = 'data',
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
years: list[int] = [2023],
@ -294,7 +295,7 @@ class CDL(RasterDataset):
# Check if the zip files have already been downloaded
exists = []
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
for year in self.years:
pathname = os.path.join(
self.paths, self.zipfile_glob.replace('*', str(year))
@ -327,7 +328,7 @@ class CDL(RasterDataset):
def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
for year in self.years:
zipfile_name = self.zipfile_glob.replace('*', str(year))
pathname = os.path.join(self.paths, zipfile_name)

@ -14,7 +14,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, lazy_import, percentile_normalization
from .utils import Path, download_url, lazy_import, percentile_normalization
class ChaBuD(NonGeoDataset):
@ -75,7 +75,7 @@ class ChaBuD(NonGeoDataset):
def __init__(
self,
root: str = 'data',
root: Path = 'data',
split: str = 'train',
bands: list[str] = all_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,

@ -5,6 +5,7 @@
import abc
import os
import pathlib
import sys
from collections.abc import Callable, Iterable, Sequence
from typing import Any, cast
@ -26,7 +27,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import GeoDataset, RasterDataset
from .nlcd import NLCD
from .utils import BoundingBox, download_url, extract_archive
from .utils import BoundingBox, Path, download_url, extract_archive
class Chesapeake(RasterDataset, abc.ABC):
@ -91,7 +92,7 @@ class Chesapeake(RasterDataset, abc.ABC):
def __init__(
self,
paths: str | Iterable[str] = 'data',
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
@ -145,7 +146,7 @@ class Chesapeake(RasterDataset, abc.ABC):
return
# Check if the zip file has already been downloaded
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
if os.path.exists(os.path.join(self.paths, self.zipfile)):
self._extract()
return
@ -164,7 +165,7 @@ class Chesapeake(RasterDataset, abc.ABC):
def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
extract_archive(os.path.join(self.paths, self.zipfile))
def plot(
@ -510,7 +511,7 @@ class ChesapeakeCVPR(GeoDataset):
def __init__(
self,
root: str = 'data',
root: Path = 'data',
splits: Sequence[str] = ['de-train'],
layers: Sequence[str] = ['naip-new', 'lc'],
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
@ -668,7 +669,7 @@ class ChesapeakeCVPR(GeoDataset):
def _verify(self) -> None:
"""Verify the integrity of the dataset."""
def exists(filename: str) -> bool:
def exists(filename: Path) -> bool:
return os.path.exists(os.path.join(self.root, filename))
# Check if the extracted files already exist

@ -16,7 +16,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
from .utils import which
from .utils import Path, which
class CloudCoverDetection(NonGeoDataset):
@ -61,7 +61,7 @@ class CloudCoverDetection(NonGeoDataset):
def __init__(
self,
root: str = 'data',
root: Path = 'data',
split: str = 'train',
bands: Sequence[str] = all_bands,
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,

@ -4,6 +4,7 @@
"""CMS Global Mangrove Canopy dataset."""
import os
import pathlib
from collections.abc import Callable
from typing import Any
@ -13,7 +14,7 @@ from rasterio.crs import CRS
from .errors import DatasetNotFoundError
from .geo import RasterDataset
from .utils import check_integrity, extract_archive
from .utils import Path, check_integrity, extract_archive
class CMSGlobalMangroveCanopy(RasterDataset):
@ -169,7 +170,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
def __init__(
self,
paths: str | list[str] = 'data',
paths: Path | list[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
measurement: str = 'agb',
@ -228,7 +229,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
return
# Check if the zip file has already been downloaded
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
pathname = os.path.join(self.paths, self.zipfile)
if os.path.exists(pathname):
if self.checksum and not check_integrity(pathname, self.md5):
@ -240,7 +241,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
def _extract(self) -> None:
"""Extract the dataset."""
assert isinstance(self.paths, str)
assert isinstance(self.paths, str | pathlib.Path)
pathname = os.path.join(self.paths, self.zipfile)
extract_archive(pathname)

@ -18,7 +18,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import check_integrity, download_and_extract_archive
from .utils import Path, check_integrity, download_and_extract_archive
class COWC(NonGeoDataset, abc.ABC):
@ -65,7 +65,7 @@ class COWC(NonGeoDataset, abc.ABC):
def __init__(
self,
root: str = 'data',
root: Path = 'data',
split: str = 'train',
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,

@ -17,7 +17,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import download_url, extract_archive, lazy_import
from .utils import Path, download_url, extract_archive, lazy_import
class CropHarvest(NonGeoDataset):
@ -96,7 +96,7 @@ class CropHarvest(NonGeoDataset):
def __init__(
self,
root: str = 'data',
root: Path = 'data',
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
@ -157,7 +157,7 @@ class CropHarvest(NonGeoDataset):
"""
return len(self.files)
def _load_features(self, root: str) -> list[dict[str, str]]:
def _load_features(self, root: Path) -> list[dict[str, str]]:
"""Return the paths of the files in the dataset.
Args:
@ -181,7 +181,7 @@ class CropHarvest(NonGeoDataset):
files.append(dict(chip=chip_path, index=index, dataset=dataset))
return files
def _load_labels(self, root: str) -> pd.DataFrame:
def _load_labels(self, root: Path) -> pd.DataFrame:
"""Return the paths of the files in the dataset.
Args:
@ -196,7 +196,7 @@ class CropHarvest(NonGeoDataset):
df = pd.json_normalize(data['features'])
return df
def _load_array(self, path: str) -> Tensor:
def _load_array(self, path: Path) -> Tensor:
"""Load an individual single pixel time series.
Args:

@ -16,7 +16,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
from .utils import which
from .utils import Path, which
class CV4AKenyaCropType(NonGeoDataset):
@ -104,7 +104,7 @@ class CV4AKenyaCropType(NonGeoDataset):
def __init__(
self,
root: str = 'data',
root: Path = 'data',
chip_size: int = 256,
stride: int = 128,
bands: Sequence[str] = all_bands,

@ -18,7 +18,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import which
from .utils import Path, which
class TropicalCyclone(NonGeoDataset):
@ -53,7 +53,7 @@ class TropicalCyclone(NonGeoDataset):
def __init__(
self,
root: str = 'data',
root: Path = 'data',
split: str = 'train',
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,

@ -16,6 +16,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import (
Path,
check_integrity,
draw_semantic_segmentation_masks,
extract_archive,
@ -100,7 +101,7 @@ class DeepGlobeLandCover(NonGeoDataset):
def __init__(
self,
root: str = 'data',
root: Path = 'data',
split: str = 'train',
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
checksum: bool = False,

@ -18,7 +18,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import check_integrity, extract_archive, percentile_normalization
from .utils import Path, check_integrity, extract_archive, percentile_normalization
class DFC2022(NonGeoDataset):
@ -137,7 +137,7 @@ class DFC2022(NonGeoDataset):
def __init__(
self,
root: str = 'data',
root: Path = 'data',
split: str = 'train',
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
checksum: bool = False,
@ -224,7 +224,7 @@ class DFC2022(NonGeoDataset):
return files
def _load_image(self, path: str, shape: Sequence[int] | None = None) -> Tensor:
def _load_image(self, path: Path, shape: Sequence[int] | None = None) -> Tensor:
"""Load a single image.
Args:
@ -241,7 +241,7 @@ class DFC2022(NonGeoDataset):
tensor = torch.from_numpy(array)
return tensor
def _load_target(self, path: str) -> Tensor:
def _load_target(self, path: Path) -> Tensor:
"""Load the target mask for a single image.
Args:

Some files were not shown because too many files changed in this diff.