LEVIRCD: data module tests without download (#2231)
* LEVIRCD: data module tests without download * Skip args
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
|
@ -5,7 +5,6 @@
|
|||
|
||||
import hashlib
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
|
||||
import numpy as np
|
||||
|
@ -32,8 +31,11 @@ if __name__ == '__main__':
|
|||
directories = ['A', 'B', 'label']
|
||||
|
||||
for split, filename in zip(splits, filenames):
|
||||
if os.path.exists(filename):
|
||||
os.remove(filename)
|
||||
|
||||
for directory in directories:
|
||||
os.mkdir(directory)
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
|
||||
for i in range(2):
|
||||
path = os.path.join('A', f'{split}_{i}.png')
|
||||
|
@ -51,9 +53,6 @@ if __name__ == '__main__':
|
|||
for file in os.listdir(directory):
|
||||
f.write(os.path.join(directory, file))
|
||||
|
||||
for directory in directories:
|
||||
shutil.rmtree(directory)
|
||||
|
||||
# compute checksum
|
||||
with open(filename, 'rb') as f:
|
||||
md5 = hashlib.md5(f.read()).hexdigest()
|
||||
|
|
После Ширина: | Высота: | Размер: 67 B |
После Ширина: | Высота: | Размер: 67 B |
После Ширина: | Высота: | Размер: 67 B |
После Ширина: | Высота: | Размер: 67 B |
После Ширина: | Высота: | Размер: 67 B |
После Ширина: | Высота: | Размер: 67 B |
Двоичные данные
tests/data/levircd/levircd/test.zip
Двоичные данные
tests/data/levircd/levircd/train.zip
Двоичные данные
tests/data/levircd/levircd/val.zip
Двоичные данные
tests/data/levircd/levircdplus/LEVIR-CD+.zip
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 67 B |
После Ширина: | Высота: | Размер: 67 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 69 B |
После Ширина: | Высота: | Размер: 67 B |
После Ширина: | Высота: | Размер: 67 B |
|
@ -57,5 +57,3 @@ if __name__ == '__main__':
|
|||
with open(f'{root}.zip', 'rb') as f:
|
||||
md5 = hashlib.md5(f.read()).hexdigest()
|
||||
print(f'{root}.zip: {md5}')
|
||||
|
||||
shutil.rmtree(root)
|
||||
|
|
|
@ -2,23 +2,14 @@
|
|||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torchvision.transforms.functional as F
|
||||
from lightning.pytorch import Trainer
|
||||
from pytest import MonkeyPatch
|
||||
from torch import Tensor
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datamodules import LEVIRCDDataModule, LEVIRCDPlusDataModule
|
||||
from torchgeo.datasets import LEVIRCD, LEVIRCDPlus
|
||||
|
||||
|
||||
def download_url(url: str, root: str, *args: str) -> None:
|
||||
shutil.copy(url, root)
|
||||
|
||||
|
||||
def transforms(sample: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
|
@ -44,23 +35,10 @@ def transforms(sample: dict[str, Tensor]) -> dict[str, Tensor]:
|
|||
|
||||
class TestLEVIRCDPlusDataModule:
|
||||
@pytest.fixture
|
||||
def datamodule(
|
||||
self, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||
) -> LEVIRCDPlusDataModule:
|
||||
monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url)
|
||||
md5 = '0ccca34310bfe7096dadfbf05b0d180f'
|
||||
monkeypatch.setattr(LEVIRCDPlus, 'md5', md5)
|
||||
url = os.path.join('tests', 'data', 'levircd', 'levircdplus', 'LEVIR-CD+.zip')
|
||||
monkeypatch.setattr(LEVIRCDPlus, 'url', url)
|
||||
|
||||
root = str(tmp_path)
|
||||
def datamodule(self) -> LEVIRCDPlusDataModule:
|
||||
root = os.path.join('tests', 'data', 'levircd', 'levircdplus')
|
||||
dm = LEVIRCDPlusDataModule(
|
||||
root=root,
|
||||
download=True,
|
||||
num_workers=0,
|
||||
checksum=True,
|
||||
val_split_pct=0.5,
|
||||
transforms=transforms,
|
||||
root=root, num_workers=0, val_split_pct=0.5, transforms=transforms
|
||||
)
|
||||
dm.prepare_data()
|
||||
dm.trainer = Trainer(accelerator='cpu', max_epochs=1)
|
||||
|
@ -113,36 +91,9 @@ class TestLEVIRCDPlusDataModule:
|
|||
|
||||
class TestLEVIRCDDataModule:
|
||||
@pytest.fixture
|
||||
def datamodule(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> LEVIRCDDataModule:
|
||||
directory = os.path.join('tests', 'data', 'levircd', 'levircd')
|
||||
splits = {
|
||||
'train': {
|
||||
'url': os.path.join(directory, 'train.zip'),
|
||||
'filename': 'train.zip',
|
||||
'md5': '7c2e24b3072095519f1be7eb01fae4ff',
|
||||
},
|
||||
'val': {
|
||||
'url': os.path.join(directory, 'val.zip'),
|
||||
'filename': 'val.zip',
|
||||
'md5': '5c320223ba88b6fc8ff9d1feebc3b84e',
|
||||
},
|
||||
'test': {
|
||||
'url': os.path.join(directory, 'test.zip'),
|
||||
'filename': 'test.zip',
|
||||
'md5': '021db72d4486726d6a0702563a617b32',
|
||||
},
|
||||
}
|
||||
monkeypatch.setattr(torchgeo.datasets.utils, 'download_url', download_url)
|
||||
monkeypatch.setattr(LEVIRCD, 'splits', splits)
|
||||
|
||||
root = str(tmp_path)
|
||||
dm = LEVIRCDDataModule(
|
||||
root=root,
|
||||
download=True,
|
||||
num_workers=0,
|
||||
checksum=True,
|
||||
transforms=transforms,
|
||||
)
|
||||
def datamodule(self) -> LEVIRCDDataModule:
|
||||
root = os.path.join('tests', 'data', 'levircd', 'levircd')
|
||||
dm = LEVIRCDDataModule(root=root, num_workers=0, transforms=transforms)
|
||||
dm.prepare_data()
|
||||
dm.trainer = Trainer(accelerator='cpu', max_epochs=1)
|
||||
return dm
|
||||
|
|