SpaceNet: add SpaceNet 8, radiant mlhub -> aws (#2203)

This commit is contained in:
Adam J. Stewart 2024-08-17 20:49:48 +02:00 коммит произвёл GitHub
Родитель 17b5ccf5ff
Коммит 880593e7ef
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
65 изменённых файлов: 612 добавлений и 1597 удалений

Просмотреть файл

@ -408,6 +408,7 @@ SpaceNet
.. autoclass:: SpaceNet5
.. autoclass:: SpaceNet6
.. autoclass:: SpaceNet7
.. autoclass:: SpaceNet8
SSL4EO
^^^^^^

Просмотреть файл

@ -91,8 +91,6 @@ datasets = [
"pycocotools>=2.0.7",
# pyvista 0.34.2+ required to avoid ImportError in CI
"pyvista>=0.34.2",
# radiant-mlhub 0.3+ required for newer tqdm support required by lightning
"radiant-mlhub>=0.3",
# scikit-image 0.19+ required for Python 3.10 wheels
"scikit-image>=0.19",
# scipy 1.7.2+ required for Python 3.10 wheels

Просмотреть файл

@ -4,6 +4,5 @@ laspy==2.5.4
opencv-python==4.10.0.84
pycocotools==2.0.8
pyvista==0.44.1
radiant-mlhub==0.4.1
scikit-image==0.24.0
scipy==1.14.0

Просмотреть файл

@ -27,7 +27,6 @@ laspy==2.0.0
opencv-python==4.5.4.58
pycocotools==2.0.7
pyvista==0.34.2
radiant-mlhub==0.3.0
scikit-image==0.19.0
scipy==1.7.2

Просмотреть файл

@ -12,7 +12,7 @@ data:
class_path: SpaceNet1DataModule
init_args:
batch_size: 1
val_split_pct: 0.33
test_split_pct: 0.33
val_split_pct: 0.34
test_split_pct: 0.34
dict_kwargs:
root: "tests/data/spacenet"

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Просмотреть файл

@ -0,0 +1 @@
{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "geometry": {"type": "Polygon", "coordinates": [[[-43.7720361, -22.922229499999958, 0.0], [-43.772064, -22.9222724, 0.0], [-43.77210239999994, -22.922247399999947, 0.0], [-43.772074499999974, -22.9222046, 0.0], [-43.7720361, -22.922229499999958, 0.0]]]}}]}

Просмотреть файл

@ -3,279 +3,90 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import json
import os
import shutil
from collections import OrderedDict
from typing import cast
import fiona
import numpy as np
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine
from torchvision.datasets.utils import calculate_md5
from torchgeo.datasets import (
SpaceNet,
SpaceNet1,
SpaceNet2,
SpaceNet3,
SpaceNet4,
SpaceNet5,
SpaceNet6,
SpaceNet7,
)
SIZE = 2
transform = Affine(0.3, 0.0, 616500.0, 0.0, -0.3, 3345000.0)
crs = CRS.from_epsg(4326)
dataset_id = 'SN1_buildings'
img_count = {
'MS.tif': 8,
'PAN.tif': 1,
'PS-MS.tif': 8,
'PS-RGB.tif': 3,
'PS-RGBNIR.tif': 4,
'RGB.tif': 3,
'RGBNIR.tif': 4,
'SAR-Intensity.tif': 1,
'mosaic.tif': 3,
'8Band.tif': 8,
profile = {
'driver': 'GTiff',
'dtype': 'uint8',
'width': SIZE,
'height': SIZE,
'crs': CRS.from_epsg(4326),
'transform': Affine(
4.489235388119662e-06,
0.0,
-43.7732462563,
0.0,
-4.486127586210932e-06,
-22.9214851954,
),
}
np.random.seed(0)
Z = np.random.randint(np.iinfo('uint8').max, size=(SIZE, SIZE), dtype='uint8')
sn4_catalog = [
'10300100023BC100',
'10300100036D5200',
'1030010003BDDC00',
'1030010003CD4300',
]
sn4_angles = [8, 30, 52, 53]
sn4_imgdirname = 'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3730989-nadir{}_catid_{}'
sn4_lbldirname = 'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3730989-labels'
sn4_emptyimgdirname = (
'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3720639-nadir53_'
+ 'catid_1030010003CD4300'
)
sn4_emptylbldirname = 'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3720639-labels'
datasets = [SpaceNet1, SpaceNet2, SpaceNet3, SpaceNet4, SpaceNet5, SpaceNet6, SpaceNet7]
def create_test_image(img_dir: str, imgs: list[str]) -> list[list[float]]:
    """Write small GeoTIFF test images and return a footprint ring.

    Every file is a 2x2 ``uint16`` raster. Its band count is looked up in the
    module-level ``img_count`` table, and it is written with the module-level
    ``crs`` and ``transform``.

    Args:
        img_dir: Directory the images are written to.
        imgs: File names to create; each must be a key of ``img_count``.

    Returns:
        Five ``[x, y]`` points forming a closed ring, obtained by applying the
        affine transform of the last image written to fixed pixel coordinates.

    Raises:
        ValueError: If ``imgs`` is empty, because there is then no image to
            read the transform from.
    """
    if not imgs:
        raise ValueError('imgs must contain at least one image name')

    # The same pixel values are written to every band of every image.
    Z = np.arange(4, dtype='uint16').reshape(2, 2)
    for img in imgs:
        imgpath = os.path.join(img_dir, img)
        with rasterio.open(
            imgpath,
            'w',
            driver='GTiff',
            height=Z.shape[0],
            width=Z.shape[1],
            count=img_count[img],
            dtype=Z.dtype,
            crs=crs,
            transform=transform,
        ) as dst:
            for i in range(1, dst.count + 1):
                dst.write(Z, i)

    slice_index = [[1, 1], [1, 2], [2, 2], [2, 1], [1, 1]]
    # Reopen the last image inside a context manager so the handle is closed
    # once the transform has been read.
    with rasterio.open(imgpath) as tim:
        return [list(tim.transform * p) for p in slice_index]
def create_test_label(
lbldir: str,
lblname: str,
coords: list[list[float]],
det_type: str,
empty: bool = False,
diff_crs: bool = False,
) -> None:
"""Create test label
Args:
lbldir (str): Name of label directory
lblname (str): Name of label file
coords (List[Tuple[float, float]]): Boundary coordinates
det_type (str): Type of dataset. Must be either buildings or roads.
empty (bool, optional): Creates empty label file if True. Defaults to False.
diff_crs (bool, optional): Assigns EPSG:3857 as CRS instead of
default EPSG:4326. Defaults to False.
"""
if empty:
# Creates a new file
with open(os.path.join(lbldir, lblname), 'w'):
pass
return
if det_type == 'buildings':
meta_properties = OrderedDict()
geom = 'Polygon'
rec = {
'type': 'Feature',
'id': '0',
'properties': OrderedDict(),
'geometry': {'type': 'Polygon', 'coordinates': [coords]},
}
else:
meta_properties = OrderedDict(
[
('heading', 'str'),
('lane_number', 'str'),
('one_way_ty', 'str'),
('paved', 'str'),
('road_id', 'int'),
('road_type', 'str'),
('origarea', 'int'),
('origlen', 'float'),
('partialDec', 'int'),
('truncated', 'int'),
('bridge_type', 'str'),
('inferred_speed_mph', 'float'),
('inferred_speed_mps', 'float'),
]
for count in [3, 8]:
os.makedirs(os.path.join(dataset_id, 'train', f'{count}band'), exist_ok=True)
for i in range(1, 5):
path = os.path.join(
dataset_id, 'train', f'{count}band', f'3band_AOI_1_RIO_img{i}.tif'
)
geom = 'LineString'
profile['count'] = count
with rasterio.open(path, 'w', **profile) as src:
for j in range(1, count + 1):
src.write(Z, j)
dummy_vals = {'str': 'a', 'float': 45.0, 'int': 0}
ROAD_DICT = [(k, dummy_vals[v]) for k, v in meta_properties.items()]
rec = {
shutil.make_archive(
os.path.join(dataset_id, 'train', f'SN1_buildings_train_AOI_1_Rio_{count}band'),
'gztar',
os.path.join(dataset_id, 'train'),
f'{count}band',
)
geojson = {
'type': 'FeatureCollection',
'crs': {'type': 'name', 'properties': {'name': 'urn:ogc:def:crs:OGC:1.3:CRS84'}},
'features': [
{
'type': 'Feature',
'id': '0',
'properties': OrderedDict(ROAD_DICT),
'geometry': {'type': 'LineString', 'coordinates': [coords[0], coords[2]]},
'geometry': {
'type': 'Polygon',
'coordinates': [
[
[-43.7720361, -22.922229499999958, 0.0],
[-43.772064, -22.9222724, 0.0],
[-43.772102399999937, -22.922247399999947, 0.0],
[-43.772074499999974, -22.9222046, 0.0],
[-43.7720361, -22.922229499999958, 0.0],
]
],
},
}
],
}
meta = {
'driver': 'GeoJSON',
'schema': {'properties': meta_properties, 'geometry': geom},
'crs': {'init': 'epsg:4326'},
}
if diff_crs:
meta['crs'] = {'init': 'epsg:3857'}
out_file = os.path.join(lbldir, lblname)
with fiona.open(out_file, 'w', **meta) as dst:
dst.write(rec)
os.makedirs(os.path.join(dataset_id, 'train', 'geojson'), exist_ok=True)
for i in range(1, 4):
path = os.path.join(dataset_id, 'train', 'geojson', f'Geo_AOI_1_RIO_img{i}.geojson')
with open(path, 'w') as src:
if i % 2 == 0:
json.dump(geojson, src)
def main() -> None:
ROOT_DIR = os.path.dirname(os.path.realpath(__file__))
for dataset in datasets:
collections = list(dataset.collection_md5_dict.keys())
for collection in collections:
dataset = cast(SpaceNet, dataset)
if dataset.dataset_id == 'spacenet4':
num_samples = 4
elif collection == 'sn5_AOI_7_Moscow' or collection not in [
'sn5_AOI_8_Mumbai',
'sn7_test_source',
]:
num_samples = 3
elif collection == 'sn5_AOI_8_Mumbai':
num_samples = 3
else:
num_samples = 1
for sample in range(num_samples):
out_dir = os.path.join(ROOT_DIR, collection)
if collection == 'sn6_AOI_11_Rotterdam':
out_dir = os.path.join(ROOT_DIR, 'spacenet6', collection)
# Create img dir
if dataset.dataset_id == 'spacenet4':
assert num_samples == 4
if sample != 3:
imgdirname = sn4_imgdirname.format(
sn4_angles[sample], sn4_catalog[sample]
)
lbldirname = sn4_lbldirname
else:
imgdirname = sn4_emptyimgdirname.format(
sn4_angles[sample], sn4_catalog[sample]
)
lbldirname = sn4_emptylbldirname
else:
imgdirname = f'{collection}_img{sample + 1}'
lbldirname = f'{collection}_img{sample + 1}-labels'
imgdir = os.path.join(out_dir, imgdirname)
os.makedirs(imgdir, exist_ok=True)
bounds = create_test_image(imgdir, list(dataset.imagery.values()))
# Create lbl dir
lbldir = os.path.join(out_dir, lbldirname)
os.makedirs(lbldir, exist_ok=True)
det_type = 'roads' if dataset in [SpaceNet3, SpaceNet5] else 'buildings'
if dataset.dataset_id == 'spacenet4' and sample == 3:
# Creates an empty file
create_test_label(
lbldir, dataset.label_glob, bounds, det_type, empty=True
)
else:
create_test_label(lbldir, dataset.label_glob, bounds, det_type)
if collection == 'sn5_AOI_8_Mumbai':
if sample == 1:
create_test_label(
lbldir, dataset.label_glob, bounds, det_type, empty=True
)
if sample == 2:
create_test_label(
lbldir, dataset.label_glob, bounds, det_type, diff_crs=True
)
if collection == 'sn1_AOI_1_RIO' and sample == 1:
create_test_label(
lbldir, dataset.label_glob, bounds, det_type, diff_crs=True
)
if collection not in [
'sn2_AOI_2_Vegas',
'sn3_AOI_5_Khartoum',
'sn4_AOI_6_Atlanta',
'sn5_AOI_8_Mumbai',
'sn6_AOI_11_Rotterdam',
'sn7_train_source',
]:
# Create collection.json
with open(
os.path.join(ROOT_DIR, collection, 'collection.json'), 'w'
):
pass
if collection == 'sn6_AOI_11_Rotterdam':
# Create collection.json
with open(
os.path.join(
ROOT_DIR, 'spacenet6', collection, 'collection.json'
),
'w',
):
pass
# Create archive
if collection == 'sn6_AOI_11_Rotterdam':
break
archive_path = os.path.join(ROOT_DIR, collection)
shutil.make_archive(
archive_path, 'gztar', root_dir=ROOT_DIR, base_dir=collection
)
shutil.rmtree(out_dir)
print(f'{collection}: {calculate_md5(f"{archive_path}.tar.gz")}')
if __name__ == '__main__':
main()
shutil.make_archive(
os.path.join(
dataset_id, 'train', 'SN1_buildings_train_AOI_1_Rio_geojson_buildings'
),
'gztar',
os.path.join(dataset_id, 'train'),
'geojson',
)

Двоичные данные
tests/data/spacenet/sn1_AOI_1_RIO.tar.gz

Двоичный файл не отображается.

Просмотреть файл

@ -1,7 +0,0 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 616500.300000000046566, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.4 ], [ 616500.599999999976717, 3344999.4 ], [ 616500.599999999976717, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.700000000186265 ] ] ] } }
]
}

Двоичный файл не отображается.

Двоичный файл не отображается.

Просмотреть файл

@ -1,7 +0,0 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:EPSG::3857" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 616500.300000000046566, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.4 ], [ 616500.599999999976717, 3344999.4 ], [ 616500.599999999976717, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.700000000186265 ] ] ] } }
]
}

Двоичный файл не отображается.

Двоичный файл не отображается.

Просмотреть файл

@ -1,7 +0,0 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 616500.300000000046566, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.4 ], [ 616500.599999999976717, 3344999.4 ], [ 616500.599999999976717, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.700000000186265 ] ] ] } }
]
}

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn2_AOI_2_Vegas.tar.gz

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn2_AOI_3_Paris.tar.gz

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn2_AOI_4_Shanghai.tar.gz

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn2_AOI_5_Khartoum.tar.gz

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn3_AOI_2_Vegas.tar.gz

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn3_AOI_3_Paris.tar.gz

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn3_AOI_4_Shanghai.tar.gz

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn4_AOI_6_Atlanta.tar.gz

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn5_AOI_7_Moscow.tar.gz

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn5_AOI_8_Mumbai.tar.gz

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn7_test_source.tar.gz

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn7_train_labels.tar.gz

Двоичный файл не отображается.

Двоичные данные
tests/data/spacenet/sn7_train_source.tar.gz

Двоичный файл не отображается.

Просмотреть файл

@ -1,7 +0,0 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 616500.300000000046566, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.4 ], [ 616500.599999999976717, 3344999.4 ], [ 616500.599999999976717, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.700000000186265 ] ] ] } }
]
}

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Просмотреть файл

@ -1,7 +0,0 @@
{
"type": "FeatureCollection",
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
"features": [
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 616500.300000000046566, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.4 ], [ 616500.599999999976717, 3344999.4 ], [ 616500.599999999976717, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.700000000186265 ] ] ] } }
]
}

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

Двоичный файл не отображается.

1
tests/datasets/aws Symbolic link
Просмотреть файл

@ -0,0 +1 @@
aws.py

6
tests/datasets/aws.bat Normal file
Просмотреть файл

@ -0,0 +1,6 @@
REM Copyright (c) Microsoft Corporation. All rights reserved.
REM Licensed under the MIT License.
@ECHO OFF
python3 tests\datasets\aws.py %*

20
tests/datasets/aws.py Executable file
Просмотреть файл

@ -0,0 +1,20 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Basic mock-up of the AWS CLI."""
import argparse
import shutil
if __name__ == '__main__':
    # Accept just enough of the CLI grammar for ``aws s3 cp SOURCE DESTINATION``.
    root_parser = argparse.ArgumentParser()
    s3_parser = root_parser.add_subparsers().add_parser('s3')
    cp_parser = s3_parser.add_subparsers().add_parser('cp')
    for positional in ('source', 'destination'):
        cp_parser.add_argument(positional)

    # Unrecognised arguments are collected and ignored rather than rejected.
    known, _unknown = root_parser.parse_known_args()

    # The "transfer" is a plain local file copy.
    shutil.copy(known.source, known.destination)

Просмотреть файл

@ -31,6 +31,13 @@ def download_url(monkeypatch: MonkeyPatch, request: SubRequest) -> None:
pass
@pytest.fixture
def aws(monkeypatch: MonkeyPatch) -> Executable:
    """Put this file's directory at the front of ``PATH`` and look up ``aws``.

    Returns:
        Whatever ``which('aws')`` resolves to once ``PATH`` has been patched.
    """
    tests_dir = os.path.dirname(os.path.realpath(__file__))
    monkeypatch.setenv('PATH', tests_dir, prepend=os.pathsep)
    return which('aws')
@pytest.fixture
def azcopy(monkeypatch: MonkeyPatch) -> Executable:
path = os.path.dirname(os.path.realpath(__file__))

Просмотреть файл

@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import glob
import os
import shutil
from pathlib import Path
@ -10,437 +9,61 @@ import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch
from torchgeo.datasets import (
DatasetNotFoundError,
SpaceNet1,
SpaceNet2,
SpaceNet3,
SpaceNet4,
SpaceNet5,
SpaceNet6,
SpaceNet7,
)
TEST_DATA_DIR = 'tests/data/spacenet'
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
from torchgeo.datasets import DatasetNotFoundError, SpaceNet1
from torchgeo.datasets.utils import Executable
class Collection:
    """Local stand-in returned in place of ``radiant_mlhub.Collection.fetch``."""

    def __init__(self, collection_id: str) -> None:
        self.collection_id = collection_id

    def download(self, output_dir: str, **kwargs: str) -> None:
        """Copy every ``*.tar.gz`` directly under ``TEST_DATA_DIR`` to ``output_dir``.

        Args:
            output_dir: Directory the archives are copied into.
            **kwargs: Accepted and ignored.
        """
        pattern = os.path.join(TEST_DATA_DIR, '*.tar.gz')
        for archive in glob.iglob(pattern):
            shutil.copy(archive, output_dir)
class Dataset:
    """Local stand-in returned in place of ``radiant_mlhub.Dataset.fetch``."""

    def __init__(self, dataset_id: str) -> None:
        self.dataset_id = dataset_id

    def download(self, output_dir: str, **kwargs: str) -> None:
        """Copy every ``spacenet*`` entry under ``TEST_DATA_DIR`` into ``output_dir``.

        Each match is copied to ``output_dir/<basename>``.

        Args:
            output_dir: Directory that receives one sub-directory per match.
            **kwargs: Accepted and ignored.
        """
        glob_path = os.path.join(TEST_DATA_DIR, 'spacenet*')
        for directory in glob.iglob(glob_path):
            dataset_name = os.path.basename(directory)
            # Build the destination in its own variable: reassigning
            # ``output_dir`` here would nest every later match inside the
            # previous match's destination.
            destination = os.path.join(output_dir, dataset_name)
            shutil.copytree(directory, destination)
def fetch_collection(collection_id: str, **kwargs: str) -> Collection:
    """Return a local ``Collection`` for ``collection_id``; other arguments are ignored."""
    return Collection(collection_id=collection_id)
def fetch_dataset(dataset_id: str, **kwargs: str) -> Dataset:
    """Return a local ``Dataset`` for ``dataset_id``; other arguments are ignored."""
    return Dataset(dataset_id=dataset_id)
class TestSpaceNet1:
@pytest.fixture(params=['rgb', '8band'])
class TestSpaceNet:
@pytest.fixture
def dataset(
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
self, aws: Executable, monkeypatch: MonkeyPatch, tmp_path: Path
) -> SpaceNet1:
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection)
test_md5 = {'sn1_AOI_1_RIO': '127a523561987110f008e8c9815ce807'}
# Refer https://github.com/python/mypy/issues/1032
monkeypatch.setattr(SpaceNet1, 'collection_md5_dict', test_md5)
root = tmp_path
transforms = nn.Identity()
return SpaceNet1(
root, image=request.param, transforms=transforms, download=True, api_key=''
url = os.path.join(
'tests', 'data', 'spacenet', '{dataset_id}', 'train', '{tarball}'
)
monkeypatch.setattr(SpaceNet1, 'url', url)
transforms = nn.Identity()
return SpaceNet1(tmp_path, transforms=transforms, download=True)
def test_getitem(self, dataset: SpaceNet1) -> None:
x = dataset[0]
dataset[1]
@pytest.mark.parametrize('index', [0, 1])
def test_getitem(self, dataset: SpaceNet1, index: int) -> None:
x = dataset[index]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
assert isinstance(x['mask'], torch.Tensor)
if dataset.image == 'rgb':
assert x['image'].shape[0] == 3
else:
assert x['image'].shape[0] == 8
def test_len(self, dataset: SpaceNet1) -> None:
assert len(dataset) == 3
def test_already_extracted(self, dataset: SpaceNet1) -> None:
SpaceNet1(root=dataset.root)
def test_already_downloaded(self, dataset: SpaceNet1) -> None:
SpaceNet1(root=dataset.root, download=True)
for product in ['3band', '8band', 'geojson']:
dir = os.path.join(dataset.root, dataset.dataset_id, dataset.split, product)
shutil.rmtree(dir)
SpaceNet1(root=dataset.root)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SpaceNet1(tmp_path)
def test_plot(self, dataset: SpaceNet1) -> None:
x = dataset[0].copy()
x['prediction'] = x['mask']
dataset.plot(x, suptitle='Test')
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
class TestSpaceNet2:
@pytest.fixture(params=['PAN', 'MS', 'PS-MS', 'PS-RGB'])
def dataset(
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
) -> SpaceNet2:
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection)
test_md5 = {
'sn2_AOI_2_Vegas': '131048686ba21a45853c05f227f40b7f',
'sn2_AOI_3_Paris': '62242fd198ee32b59f0178cf656e1513',
'sn2_AOI_4_Shanghai': '563b0817ecedd8ff3b3e4cb2991bf3fb',
'sn2_AOI_5_Khartoum': 'e4185a2e9a12cf7b3d0cd1db6b3e0f06',
}
monkeypatch.setattr(SpaceNet2, 'collection_md5_dict', test_md5)
root = tmp_path
transforms = nn.Identity()
return SpaceNet2(
root,
image=request.param,
collections=['sn2_AOI_2_Vegas', 'sn2_AOI_5_Khartoum'],
transforms=transforms,
download=True,
api_key='',
)
def test_getitem(self, dataset: SpaceNet2) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
assert isinstance(x['mask'], torch.Tensor)
if dataset.image == 'PS-RGB':
assert x['image'].shape[0] == 3
elif dataset.image in ['MS', 'PS-MS']:
assert x['image'].shape[0] == 8
else:
assert x['image'].shape[0] == 1
def test_len(self, dataset: SpaceNet2) -> None:
assert len(dataset) == 4
def test_already_downloaded(self, dataset: SpaceNet2) -> None:
SpaceNet2(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SpaceNet2(tmp_path)
def test_collection_checksum(self, dataset: SpaceNet2) -> None:
dataset.collection_md5_dict['sn2_AOI_2_Vegas'] = 'randommd5hash123'
with pytest.raises(RuntimeError, match='Collection sn2_AOI_2_Vegas corrupted'):
SpaceNet2(root=dataset.root, download=True, checksum=True)
def test_plot(self, dataset: SpaceNet2) -> None:
x = dataset[0].copy()
dataset.plot(x, show_titles=False)
plt.close()
x['prediction'] = x['mask']
dataset.plot(x, suptitle='Test')
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
def test_image_id(self, monkeypatch: MonkeyPatch, dataset: SpaceNet1) -> None:
file_regex = r'global_monthly_(\d+.*\d+)'
monkeypatch.setattr(dataset, 'file_regex', file_regex)
dataset._image_id('global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160.tif')
class TestSpaceNet3:
@pytest.fixture(params=zip(['PAN', 'MS'], [False, True]))
def dataset(
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
) -> SpaceNet3:
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection)
test_md5 = {
'sn3_AOI_3_Paris': '93452c68da11dd6b57dc83dba43c2c9d',
'sn3_AOI_5_Khartoum': '7c9d96810198bf101cbaf54f7a5e8b3b',
}
monkeypatch.setattr(SpaceNet3, 'collection_md5_dict', test_md5)
root = tmp_path
transforms = nn.Identity()
return SpaceNet3(
root,
image=request.param[0],
speed_mask=request.param[1],
collections=['sn3_AOI_3_Paris', 'sn3_AOI_5_Khartoum'],
transforms=transforms,
download=True,
api_key='',
)
def test_getitem(self, dataset: SpaceNet3) -> None:
# Iterate over all elements to maximize coverage
samples = [dataset[i] for i in range(len(dataset))]
x = samples[0]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
assert isinstance(x['mask'], torch.Tensor)
if dataset.image == 'MS':
assert x['image'].shape[0] == 8
else:
assert x['image'].shape[0] == 1
def test_len(self, dataset: SpaceNet3) -> None:
assert len(dataset) == 4
def test_already_downloaded(self, dataset: SpaceNet3) -> None:
SpaceNet3(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SpaceNet3(tmp_path)
def test_collection_checksum(self, dataset: SpaceNet3) -> None:
dataset.collection_md5_dict['sn3_AOI_5_Khartoum'] = 'randommd5hash123'
with pytest.raises(
RuntimeError, match='Collection sn3_AOI_5_Khartoum corrupted'
):
SpaceNet3(root=dataset.root, download=True, checksum=True)
def test_plot(self, dataset: SpaceNet3) -> None:
x = dataset[0].copy()
x['prediction'] = x['mask']
dataset.plot(x, suptitle='Test')
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
dataset.plot({'image': x['image']})
plt.close()
class TestSpaceNet4:
@pytest.fixture(params=['PAN', 'MS', 'PS-RGBNIR'])
def dataset(
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
) -> SpaceNet4:
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection)
test_md5 = {'sn4_AOI_6_Atlanta': '097a76a2319b7ba34dac1722862fc93b'}
test_angles = ['nadir', 'off-nadir', 'very-off-nadir']
monkeypatch.setattr(SpaceNet4, 'collection_md5_dict', test_md5)
root = tmp_path
transforms = nn.Identity()
return SpaceNet4(
root,
image=request.param,
angles=test_angles,
transforms=transforms,
download=True,
api_key='',
)
def test_getitem(self, dataset: SpaceNet4) -> None:
# Get image-label pair with empty label to
# ensure coverage
x = dataset[2]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
assert isinstance(x['mask'], torch.Tensor)
if dataset.image == 'PS-RGBNIR':
assert x['image'].shape[0] == 4
elif dataset.image == 'MS':
assert x['image'].shape[0] == 8
else:
assert x['image'].shape[0] == 1
def test_len(self, dataset: SpaceNet4) -> None:
assert len(dataset) == 4
def test_already_downloaded(self, dataset: SpaceNet4) -> None:
SpaceNet4(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SpaceNet4(tmp_path)
def test_collection_checksum(self, dataset: SpaceNet4) -> None:
dataset.collection_md5_dict['sn4_AOI_6_Atlanta'] = 'randommd5hash123'
with pytest.raises(
RuntimeError, match='Collection sn4_AOI_6_Atlanta corrupted'
):
SpaceNet4(root=dataset.root, download=True, checksum=True)
def test_plot(self, dataset: SpaceNet4) -> None:
x = dataset[0].copy()
x['prediction'] = x['mask']
dataset.plot(x, suptitle='Test')
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
class TestSpaceNet5:
@pytest.fixture(params=zip(['PAN', 'MS'], [False, True]))
def dataset(
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
) -> SpaceNet5:
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection)
test_md5 = {
'sn5_AOI_7_Moscow': '5c511dd31eea739cc1f81ef5962f3d56',
'sn5_AOI_8_Mumbai': 'e00452b87bbe87feaef65f373be3978e',
}
monkeypatch.setattr(SpaceNet5, 'collection_md5_dict', test_md5)
root = tmp_path
transforms = nn.Identity()
return SpaceNet5(
root,
image=request.param[0],
speed_mask=request.param[1],
collections=['sn5_AOI_7_Moscow', 'sn5_AOI_8_Mumbai'],
transforms=transforms,
download=True,
api_key='',
)
def test_getitem(self, dataset: SpaceNet5) -> None:
# Iterate over all elements to maximize coverage
samples = [dataset[i] for i in range(len(dataset))]
x = samples[0]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
assert isinstance(x['mask'], torch.Tensor)
if dataset.image == 'MS':
assert x['image'].shape[0] == 8
else:
assert x['image'].shape[0] == 1
def test_len(self, dataset: SpaceNet5) -> None:
assert len(dataset) == 5
def test_already_downloaded(self, dataset: SpaceNet5) -> None:
SpaceNet5(root=dataset.root, download=True)
def test_not_downloaded(self, tmp_path: Path) -> None:
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
SpaceNet5(tmp_path)
def test_collection_checksum(self, dataset: SpaceNet5) -> None:
dataset.collection_md5_dict['sn5_AOI_8_Mumbai'] = 'randommd5hash123'
with pytest.raises(RuntimeError, match='Collection sn5_AOI_8_Mumbai corrupted'):
SpaceNet5(root=dataset.root, download=True, checksum=True)
def test_plot(self, dataset: SpaceNet5) -> None:
x = dataset[0].copy()
x['prediction'] = x['mask']
dataset.plot(x, suptitle='Test')
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
dataset.plot({'image': x['image']})
plt.close()
class TestSpaceNet6:
@pytest.fixture(params=['PAN', 'RGBNIR', 'PS-RGB', 'PS-RGBNIR', 'SAR-Intensity'])
def dataset(
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
) -> SpaceNet6:
monkeypatch.setattr(radiant_mlhub.Dataset, 'fetch', fetch_dataset)
root = tmp_path
transforms = nn.Identity()
return SpaceNet6(
root, image=request.param, transforms=transforms, download=True, api_key=''
)
def test_getitem(self, dataset: SpaceNet6) -> None:
x = dataset[0]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
assert isinstance(x['mask'], torch.Tensor)
if dataset.image == 'PS-RGB':
assert x['image'].shape[0] == 3
elif dataset.image in ['RGBNIR', 'PS-RGBNIR']:
assert x['image'].shape[0] == 4
else:
assert x['image'].shape[0] == 1
def test_len(self, dataset: SpaceNet6) -> None:
assert len(dataset) == 2
def test_already_downloaded(self, dataset: SpaceNet6) -> None:
SpaceNet6(root=dataset.root, download=True)
def test_plot(self, dataset: SpaceNet6) -> None:
x = dataset[0].copy()
x['prediction'] = x['mask']
dataset.plot(x, suptitle='Test')
plt.close()
dataset.plot(x, show_titles=False)
plt.close()
class TestSpaceNet7:
    """Tests for ``SpaceNet7``, run once per split ('train' and 'test')."""

    @pytest.fixture(params=['train', 'test'])
    def dataset(
        self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
    ) -> SpaceNet7:
        """Build a ``SpaceNet7`` in ``tmp_path`` with collection fetching patched."""
        monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection)
        # Override the class-level checksum table for the duration of the test.
        test_md5 = {
            'sn7_train_source': '197bfa8842a40b09b6837b824a6370e0',
            'sn7_train_labels': '625ad8a989a5105bc766a53e53df4d0e',
            'sn7_test_source': '461f59eb21bb4f416c867f5037dfceeb',
        }
        monkeypatch.setattr(SpaceNet7, 'collection_md5_dict', test_md5)
        root = tmp_path
        transforms = nn.Identity()
        return SpaceNet7(
            root, split=request.param, transforms=transforms, download=True, api_key=''
        )

    def test_getitem(self, dataset: SpaceNet7) -> None:
        """The first sample is a dict with an image, plus a mask for 'train'."""
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x['image'], torch.Tensor)
        if dataset.split == 'train':
            assert isinstance(x['mask'], torch.Tensor)

    def test_len(self, dataset: SpaceNet7) -> None:
        """The expected length is 2 for 'train' and 1 otherwise."""
        if dataset.split == 'train':
            assert len(dataset) == 2
        else:
            assert len(dataset) == 1

    # The fixture returns SpaceNet7, so it is annotated as such here
    # (this parameter was previously annotated as SpaceNet4).
    def test_already_downloaded(self, dataset: SpaceNet7) -> None:
        """Constructing again with ``download=True`` on the same root succeeds."""
        SpaceNet7(root=dataset.root, download=True)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty root without ``download=True`` raises ``DatasetNotFoundError``."""
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            SpaceNet7(tmp_path)

    # Same annotation correction as above: the fixture yields SpaceNet7.
    def test_collection_checksum(self, dataset: SpaceNet7) -> None:
        """A wrong checksum makes construction with ``checksum=True`` fail."""
        dataset.collection_md5_dict['sn7_train_source'] = 'randommd5hash123'
        with pytest.raises(RuntimeError, match='Collection sn7_train_source corrupted'):
            SpaceNet7(root=dataset.root, download=True, checksum=True)

    def test_plot(self, dataset: SpaceNet7) -> None:
        """Plot with and without titles; 'train' samples also get a prediction."""
        x = dataset[0].copy()
        if dataset.split == 'train':
            x['prediction'] = x['mask']
        dataset.plot(x, suptitle='Test')
        plt.close()
        dataset.plot(x, show_titles=False)
        plt.close()
def test_list_files(self, monkeypatch: MonkeyPatch, dataset: SpaceNet1) -> None:
directory_glob = os.path.join('**', 'AOI_{aoi}_*', '{product}')
monkeypatch.setattr(dataset, 'directory_glob', directory_glob)
dataset._list_files(aoi=1)

Просмотреть файл

@ -1,12 +1,10 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import glob
import math
import os
import pickle
import re
import shutil
import sys
from datetime import datetime
from pathlib import Path
@ -15,7 +13,6 @@ from typing import Any
import numpy as np
import pytest
import torch
from pytest import MonkeyPatch
from rasterio.crs import CRS
from torchgeo.datasets import BoundingBox, DependencyNotFoundError
@ -24,8 +21,6 @@ from torchgeo.datasets.utils import (
array_to_tensor,
concat_samples,
disambiguate_timestamp,
download_radiant_mlhub_collection,
download_radiant_mlhub_dataset,
lazy_import,
merge_samples,
percentile_normalization,
@ -36,48 +31,6 @@ from torchgeo.datasets.utils import (
)
class MLHubDataset:
    """Local stand-in used in place of a fetched Radiant MLHub dataset."""

    def download(self, output_dir: str, **kwargs: str) -> None:
        """Copy the ``*.tar.gz`` test archives into ``output_dir``.

        Archives are looked up under ``tests/data/ref_african_crops_kenya_02``,
        relative to the current working directory.

        Args:
            output_dir: Directory the archives are copied into.
            **kwargs: Accepted and ignored.
        """
        source_dir = os.path.join('tests', 'data', 'ref_african_crops_kenya_02')
        for archive in glob.iglob(os.path.join(source_dir, '*.tar.gz')):
            shutil.copy(archive, output_dir)
class Collection:
    """Local stand-in used in place of a fetched Radiant MLHub collection."""

    def download(self, output_dir: str, **kwargs: str) -> None:
        """Copy the ``*.tar.gz`` test archives into ``output_dir``.

        Archives are looked up under ``tests/data/ref_african_crops_kenya_02``,
        relative to the current working directory.

        Args:
            output_dir: Directory the archives are copied into.
            **kwargs: Accepted and ignored.
        """
        source_dir = os.path.join('tests', 'data', 'ref_african_crops_kenya_02')
        for archive in glob.iglob(os.path.join(source_dir, '*.tar.gz')):
            shutil.copy(archive, output_dir)
def fetch_dataset(dataset_id: str, **kwargs: str) -> MLHubDataset:
    """Return a fresh ``MLHubDataset``; every argument is ignored."""
    stub = MLHubDataset()
    return stub
def fetch_collection(collection_id: str, **kwargs: str) -> Collection:
    """Return a fresh ``Collection``; every argument is ignored."""
    stub = Collection()
    return stub
def test_download_radiant_mlhub_dataset(
    tmp_path: Path, monkeypatch: MonkeyPatch
) -> None:
    """Call ``download_radiant_mlhub_dataset`` with ``Dataset.fetch`` patched.

    The test is skipped when ``radiant_mlhub`` 0.3 or newer cannot be imported.
    """
    mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
    monkeypatch.setattr(mlhub.Dataset, 'fetch', fetch_dataset)
    download_radiant_mlhub_dataset('', download_root=tmp_path)
def test_download_radiant_mlhub_collection(
    tmp_path: Path, monkeypatch: MonkeyPatch
) -> None:
    """Call ``download_radiant_mlhub_collection`` with ``Collection.fetch`` patched.

    The test is skipped when ``radiant_mlhub`` 0.3 or newer cannot be imported.
    """
    mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
    monkeypatch.setattr(mlhub.Collection, 'fetch', fetch_collection)
    download_radiant_mlhub_collection('', download_root=tmp_path)
class TestBoundingBox:
def test_repr_str(self) -> None:
bbox = BoundingBox(0, 1, 2.0, 3.0, -5, -4)

Просмотреть файл

@ -111,6 +111,7 @@ from .spacenet import (
SpaceNet5,
SpaceNet6,
SpaceNet7,
SpaceNet8,
)
from .splits import (
random_bbox_assignment,
@ -244,6 +245,7 @@ __all__ = (
'SpaceNet5',
'SpaceNet6',
'SpaceNet7',
'SpaceNet8',
'SSL4EO',
'SSL4EOLBenchmark',
'SSL4EOL',

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -45,46 +45,6 @@ __all__ = (
Path: TypeAlias = str | pathlib.Path
def download_radiant_mlhub_dataset(
    dataset_id: str, download_root: Path, api_key: str | None = None
) -> None:
    """Download a dataset from Radiant Earth.

    Args:
        dataset_id: the ID of the dataset to fetch
        download_root: directory to download to
        api_key: the API key to use for all requests from the session. Can also be
            passed in via the ``MLHUB_API_KEY`` environment variable, or configured in
            ``~/.mlhub/profiles``.

    Raises:
        DependencyNotFoundError: If radiant_mlhub is not installed.
    """
    # Resolve the optional dependency only when the function is called.
    mlhub = lazy_import('radiant_mlhub')
    # The same key is passed to both the fetch and the download.
    mlhub.Dataset.fetch(dataset_id, api_key=api_key).download(
        output_dir=download_root, api_key=api_key
    )
def download_radiant_mlhub_collection(
    collection_id: str, download_root: Path, api_key: str | None = None
) -> None:
    """Download a collection from Radiant Earth.

    Args:
        collection_id: the ID of the collection to fetch
        download_root: directory to download to
        api_key: the API key to use for all requests from the session. Can also be
            passed in via the ``MLHUB_API_KEY`` environment variable, or configured in
            ``~/.mlhub/profiles``.

    Raises:
        DependencyNotFoundError: If radiant_mlhub is not installed.
    """
    # Resolve the optional dependency only when the function is called.
    mlhub = lazy_import('radiant_mlhub')
    # The same key is passed to both the fetch and the download.
    mlhub.Collection.fetch(collection_id, api_key=api_key).download(
        output_dir=download_root, api_key=api_key
    )
@dataclass(frozen=True)
class BoundingBox:
"""Data class for indexing spatiotemporal data."""