зеркало из https://github.com/microsoft/torchgeo.git
SpaceNet: add SpaceNet 8, radiant mlhub -> aws (#2203)
This commit is contained in:
Родитель
17b5ccf5ff
Коммит
880593e7ef
|
@ -408,6 +408,7 @@ SpaceNet
|
|||
.. autoclass:: SpaceNet5
|
||||
.. autoclass:: SpaceNet6
|
||||
.. autoclass:: SpaceNet7
|
||||
.. autoclass:: SpaceNet8
|
||||
|
||||
SSL4EO
|
||||
^^^^^^
|
||||
|
|
|
@ -91,8 +91,6 @@ datasets = [
|
|||
"pycocotools>=2.0.7",
|
||||
# pyvista 0.34.2+ required to avoid ImportError in CI
|
||||
"pyvista>=0.34.2",
|
||||
# radiant-mlhub 0.3+ required for newer tqdm support required by lightning
|
||||
"radiant-mlhub>=0.3",
|
||||
# scikit-image 0.19+ required for Python 3.10 wheels
|
||||
"scikit-image>=0.19",
|
||||
# scipy 1.7.2+ required for Python 3.10 wheels
|
||||
|
|
|
@ -4,6 +4,5 @@ laspy==2.5.4
|
|||
opencv-python==4.10.0.84
|
||||
pycocotools==2.0.8
|
||||
pyvista==0.44.1
|
||||
radiant-mlhub==0.4.1
|
||||
scikit-image==0.24.0
|
||||
scipy==1.14.0
|
||||
|
|
|
@ -27,7 +27,6 @@ laspy==2.0.0
|
|||
opencv-python==4.5.4.58
|
||||
pycocotools==2.0.7
|
||||
pyvista==0.34.2
|
||||
radiant-mlhub==0.3.0
|
||||
scikit-image==0.19.0
|
||||
scipy==1.7.2
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ data:
|
|||
class_path: SpaceNet1DataModule
|
||||
init_args:
|
||||
batch_size: 1
|
||||
val_split_pct: 0.33
|
||||
test_split_pct: 0.33
|
||||
val_split_pct: 0.34
|
||||
test_split_pct: 0.34
|
||||
dict_kwargs:
|
||||
root: "tests/data/spacenet"
|
||||
|
|
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/SN1_buildings/train/SN1_buildings_train_AOI_1_Rio_3band.tar.gz
Normal file
Двоичные данные
tests/data/spacenet/SN1_buildings/train/SN1_buildings_train_AOI_1_Rio_3band.tar.gz
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/SN1_buildings/train/SN1_buildings_train_AOI_1_Rio_8band.tar.gz
Normal file
Двоичные данные
tests/data/spacenet/SN1_buildings/train/SN1_buildings_train_AOI_1_Rio_8band.tar.gz
Normal file
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/SN1_buildings/train/SN1_buildings_train_AOI_1_Rio_geojson_buildings.tar.gz
Normal file
Двоичные данные
tests/data/spacenet/SN1_buildings/train/SN1_buildings_train_AOI_1_Rio_geojson_buildings.tar.gz
Normal file
Двоичный файл не отображается.
|
@ -0,0 +1 @@
|
|||
{"type": "FeatureCollection", "crs": {"type": "name", "properties": {"name": "urn:ogc:def:crs:OGC:1.3:CRS84"}}, "features": [{"type": "Feature", "geometry": {"type": "Polygon", "coordinates": [[[-43.7720361, -22.922229499999958, 0.0], [-43.772064, -22.9222724, 0.0], [-43.77210239999994, -22.922247399999947, 0.0], [-43.772074499999974, -22.9222046, 0.0], [-43.7720361, -22.922229499999958, 0.0]]]}}]}
|
|
@ -3,279 +3,90 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from typing import cast
|
||||
|
||||
import fiona
|
||||
import numpy as np
|
||||
import rasterio
|
||||
from rasterio.crs import CRS
|
||||
from rasterio.transform import Affine
|
||||
from torchvision.datasets.utils import calculate_md5
|
||||
|
||||
from torchgeo.datasets import (
|
||||
SpaceNet,
|
||||
SpaceNet1,
|
||||
SpaceNet2,
|
||||
SpaceNet3,
|
||||
SpaceNet4,
|
||||
SpaceNet5,
|
||||
SpaceNet6,
|
||||
SpaceNet7,
|
||||
)
|
||||
SIZE = 2
|
||||
|
||||
transform = Affine(0.3, 0.0, 616500.0, 0.0, -0.3, 3345000.0)
|
||||
crs = CRS.from_epsg(4326)
|
||||
dataset_id = 'SN1_buildings'
|
||||
|
||||
img_count = {
|
||||
'MS.tif': 8,
|
||||
'PAN.tif': 1,
|
||||
'PS-MS.tif': 8,
|
||||
'PS-RGB.tif': 3,
|
||||
'PS-RGBNIR.tif': 4,
|
||||
'RGB.tif': 3,
|
||||
'RGBNIR.tif': 4,
|
||||
'SAR-Intensity.tif': 1,
|
||||
'mosaic.tif': 3,
|
||||
'8Band.tif': 8,
|
||||
profile = {
|
||||
'driver': 'GTiff',
|
||||
'dtype': 'uint8',
|
||||
'width': SIZE,
|
||||
'height': SIZE,
|
||||
'crs': CRS.from_epsg(4326),
|
||||
'transform': Affine(
|
||||
4.489235388119662e-06,
|
||||
0.0,
|
||||
-43.7732462563,
|
||||
0.0,
|
||||
-4.486127586210932e-06,
|
||||
-22.9214851954,
|
||||
),
|
||||
}
|
||||
|
||||
np.random.seed(0)
|
||||
Z = np.random.randint(np.iinfo('uint8').max, size=(SIZE, SIZE), dtype='uint8')
|
||||
|
||||
sn4_catalog = [
|
||||
'10300100023BC100',
|
||||
'10300100036D5200',
|
||||
'1030010003BDDC00',
|
||||
'1030010003CD4300',
|
||||
]
|
||||
sn4_angles = [8, 30, 52, 53]
|
||||
|
||||
sn4_imgdirname = 'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3730989-nadir{}_catid_{}'
|
||||
sn4_lbldirname = 'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3730989-labels'
|
||||
sn4_emptyimgdirname = (
|
||||
'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3720639-nadir53_'
|
||||
+ 'catid_1030010003CD4300'
|
||||
)
|
||||
sn4_emptylbldirname = 'sn4_SN4_buildings_train_AOI_6_Atlanta_732701_3720639-labels'
|
||||
|
||||
|
||||
datasets = [SpaceNet1, SpaceNet2, SpaceNet3, SpaceNet4, SpaceNet5, SpaceNet6, SpaceNet7]
|
||||
|
||||
|
||||
def create_test_image(img_dir: str, imgs: list[str]) -> list[list[float]]:
|
||||
"""Create test image
|
||||
|
||||
Args:
|
||||
img_dir (str): Name of image directory
|
||||
imgs (List[str]): List of images to be created
|
||||
|
||||
Returns:
|
||||
List[List[float]]: Boundary coordinates
|
||||
"""
|
||||
for img in imgs:
|
||||
imgpath = os.path.join(img_dir, img)
|
||||
Z = np.arange(4, dtype='uint16').reshape(2, 2)
|
||||
count = img_count[img]
|
||||
with rasterio.open(
|
||||
imgpath,
|
||||
'w',
|
||||
driver='GTiff',
|
||||
height=Z.shape[0],
|
||||
width=Z.shape[1],
|
||||
count=count,
|
||||
dtype=Z.dtype,
|
||||
crs=crs,
|
||||
transform=transform,
|
||||
) as dst:
|
||||
for i in range(1, dst.count + 1):
|
||||
dst.write(Z, i)
|
||||
|
||||
tim = rasterio.open(imgpath)
|
||||
slice_index = [[1, 1], [1, 2], [2, 2], [2, 1], [1, 1]]
|
||||
return [list(tim.transform * p) for p in slice_index]
|
||||
|
||||
|
||||
def create_test_label(
|
||||
lbldir: str,
|
||||
lblname: str,
|
||||
coords: list[list[float]],
|
||||
det_type: str,
|
||||
empty: bool = False,
|
||||
diff_crs: bool = False,
|
||||
) -> None:
|
||||
"""Create test label
|
||||
|
||||
Args:
|
||||
lbldir (str): Name of label directory
|
||||
lblname (str): Name of label file
|
||||
coords (List[Tuple[float, float]]): Boundary coordinates
|
||||
det_type (str): Type of dataset. Must be either buildings or roads.
|
||||
empty (bool, optional): Creates empty label file if True. Defaults to False.
|
||||
diff_crs (bool, optional): Assigns EPSG:3857 as CRS instead of
|
||||
default EPSG:4326. Defaults to False.
|
||||
"""
|
||||
if empty:
|
||||
# Creates a new file
|
||||
with open(os.path.join(lbldir, lblname), 'w'):
|
||||
pass
|
||||
return
|
||||
|
||||
if det_type == 'buildings':
|
||||
meta_properties = OrderedDict()
|
||||
geom = 'Polygon'
|
||||
rec = {
|
||||
'type': 'Feature',
|
||||
'id': '0',
|
||||
'properties': OrderedDict(),
|
||||
'geometry': {'type': 'Polygon', 'coordinates': [coords]},
|
||||
}
|
||||
else:
|
||||
meta_properties = OrderedDict(
|
||||
[
|
||||
('heading', 'str'),
|
||||
('lane_number', 'str'),
|
||||
('one_way_ty', 'str'),
|
||||
('paved', 'str'),
|
||||
('road_id', 'int'),
|
||||
('road_type', 'str'),
|
||||
('origarea', 'int'),
|
||||
('origlen', 'float'),
|
||||
('partialDec', 'int'),
|
||||
('truncated', 'int'),
|
||||
('bridge_type', 'str'),
|
||||
('inferred_speed_mph', 'float'),
|
||||
('inferred_speed_mps', 'float'),
|
||||
]
|
||||
for count in [3, 8]:
|
||||
os.makedirs(os.path.join(dataset_id, 'train', f'{count}band'), exist_ok=True)
|
||||
for i in range(1, 5):
|
||||
path = os.path.join(
|
||||
dataset_id, 'train', f'{count}band', f'3band_AOI_1_RIO_img{i}.tif'
|
||||
)
|
||||
geom = 'LineString'
|
||||
profile['count'] = count
|
||||
with rasterio.open(path, 'w', **profile) as src:
|
||||
for j in range(1, count + 1):
|
||||
src.write(Z, j)
|
||||
|
||||
dummy_vals = {'str': 'a', 'float': 45.0, 'int': 0}
|
||||
ROAD_DICT = [(k, dummy_vals[v]) for k, v in meta_properties.items()]
|
||||
rec = {
|
||||
shutil.make_archive(
|
||||
os.path.join(dataset_id, 'train', f'SN1_buildings_train_AOI_1_Rio_{count}band'),
|
||||
'gztar',
|
||||
os.path.join(dataset_id, 'train'),
|
||||
f'{count}band',
|
||||
)
|
||||
|
||||
geojson = {
|
||||
'type': 'FeatureCollection',
|
||||
'crs': {'type': 'name', 'properties': {'name': 'urn:ogc:def:crs:OGC:1.3:CRS84'}},
|
||||
'features': [
|
||||
{
|
||||
'type': 'Feature',
|
||||
'id': '0',
|
||||
'properties': OrderedDict(ROAD_DICT),
|
||||
'geometry': {'type': 'LineString', 'coordinates': [coords[0], coords[2]]},
|
||||
'geometry': {
|
||||
'type': 'Polygon',
|
||||
'coordinates': [
|
||||
[
|
||||
[-43.7720361, -22.922229499999958, 0.0],
|
||||
[-43.772064, -22.9222724, 0.0],
|
||||
[-43.772102399999937, -22.922247399999947, 0.0],
|
||||
[-43.772074499999974, -22.9222046, 0.0],
|
||||
[-43.7720361, -22.922229499999958, 0.0],
|
||||
]
|
||||
],
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
meta = {
|
||||
'driver': 'GeoJSON',
|
||||
'schema': {'properties': meta_properties, 'geometry': geom},
|
||||
'crs': {'init': 'epsg:4326'},
|
||||
}
|
||||
if diff_crs:
|
||||
meta['crs'] = {'init': 'epsg:3857'}
|
||||
out_file = os.path.join(lbldir, lblname)
|
||||
with fiona.open(out_file, 'w', **meta) as dst:
|
||||
dst.write(rec)
|
||||
os.makedirs(os.path.join(dataset_id, 'train', 'geojson'), exist_ok=True)
|
||||
for i in range(1, 4):
|
||||
path = os.path.join(dataset_id, 'train', 'geojson', f'Geo_AOI_1_RIO_img{i}.geojson')
|
||||
with open(path, 'w') as src:
|
||||
if i % 2 == 0:
|
||||
json.dump(geojson, src)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
ROOT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
for dataset in datasets:
|
||||
collections = list(dataset.collection_md5_dict.keys())
|
||||
for collection in collections:
|
||||
dataset = cast(SpaceNet, dataset)
|
||||
if dataset.dataset_id == 'spacenet4':
|
||||
num_samples = 4
|
||||
elif collection == 'sn5_AOI_7_Moscow' or collection not in [
|
||||
'sn5_AOI_8_Mumbai',
|
||||
'sn7_test_source',
|
||||
]:
|
||||
num_samples = 3
|
||||
elif collection == 'sn5_AOI_8_Mumbai':
|
||||
num_samples = 3
|
||||
else:
|
||||
num_samples = 1
|
||||
|
||||
for sample in range(num_samples):
|
||||
out_dir = os.path.join(ROOT_DIR, collection)
|
||||
if collection == 'sn6_AOI_11_Rotterdam':
|
||||
out_dir = os.path.join(ROOT_DIR, 'spacenet6', collection)
|
||||
|
||||
# Create img dir
|
||||
if dataset.dataset_id == 'spacenet4':
|
||||
assert num_samples == 4
|
||||
if sample != 3:
|
||||
imgdirname = sn4_imgdirname.format(
|
||||
sn4_angles[sample], sn4_catalog[sample]
|
||||
)
|
||||
lbldirname = sn4_lbldirname
|
||||
else:
|
||||
imgdirname = sn4_emptyimgdirname.format(
|
||||
sn4_angles[sample], sn4_catalog[sample]
|
||||
)
|
||||
lbldirname = sn4_emptylbldirname
|
||||
else:
|
||||
imgdirname = f'{collection}_img{sample + 1}'
|
||||
lbldirname = f'{collection}_img{sample + 1}-labels'
|
||||
|
||||
imgdir = os.path.join(out_dir, imgdirname)
|
||||
os.makedirs(imgdir, exist_ok=True)
|
||||
bounds = create_test_image(imgdir, list(dataset.imagery.values()))
|
||||
|
||||
# Create lbl dir
|
||||
lbldir = os.path.join(out_dir, lbldirname)
|
||||
os.makedirs(lbldir, exist_ok=True)
|
||||
det_type = 'roads' if dataset in [SpaceNet3, SpaceNet5] else 'buildings'
|
||||
if dataset.dataset_id == 'spacenet4' and sample == 3:
|
||||
# Creates an empty file
|
||||
create_test_label(
|
||||
lbldir, dataset.label_glob, bounds, det_type, empty=True
|
||||
)
|
||||
else:
|
||||
create_test_label(lbldir, dataset.label_glob, bounds, det_type)
|
||||
|
||||
if collection == 'sn5_AOI_8_Mumbai':
|
||||
if sample == 1:
|
||||
create_test_label(
|
||||
lbldir, dataset.label_glob, bounds, det_type, empty=True
|
||||
)
|
||||
if sample == 2:
|
||||
create_test_label(
|
||||
lbldir, dataset.label_glob, bounds, det_type, diff_crs=True
|
||||
)
|
||||
|
||||
if collection == 'sn1_AOI_1_RIO' and sample == 1:
|
||||
create_test_label(
|
||||
lbldir, dataset.label_glob, bounds, det_type, diff_crs=True
|
||||
)
|
||||
|
||||
if collection not in [
|
||||
'sn2_AOI_2_Vegas',
|
||||
'sn3_AOI_5_Khartoum',
|
||||
'sn4_AOI_6_Atlanta',
|
||||
'sn5_AOI_8_Mumbai',
|
||||
'sn6_AOI_11_Rotterdam',
|
||||
'sn7_train_source',
|
||||
]:
|
||||
# Create collection.json
|
||||
with open(
|
||||
os.path.join(ROOT_DIR, collection, 'collection.json'), 'w'
|
||||
):
|
||||
pass
|
||||
if collection == 'sn6_AOI_11_Rotterdam':
|
||||
# Create collection.json
|
||||
with open(
|
||||
os.path.join(
|
||||
ROOT_DIR, 'spacenet6', collection, 'collection.json'
|
||||
),
|
||||
'w',
|
||||
):
|
||||
pass
|
||||
|
||||
# Create archive
|
||||
if collection == 'sn6_AOI_11_Rotterdam':
|
||||
break
|
||||
archive_path = os.path.join(ROOT_DIR, collection)
|
||||
shutil.make_archive(
|
||||
archive_path, 'gztar', root_dir=ROOT_DIR, base_dir=collection
|
||||
)
|
||||
shutil.rmtree(out_dir)
|
||||
print(f'{collection}: {calculate_md5(f"{archive_path}.tar.gz")}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
shutil.make_archive(
|
||||
os.path.join(
|
||||
dataset_id, 'train', 'SN1_buildings_train_AOI_1_Rio_geojson_buildings'
|
||||
),
|
||||
'gztar',
|
||||
os.path.join(dataset_id, 'train'),
|
||||
'geojson',
|
||||
)
|
||||
|
|
Двоичные данные
tests/data/spacenet/sn1_AOI_1_RIO.tar.gz
Двоичные данные
tests/data/spacenet/sn1_AOI_1_RIO.tar.gz
Двоичный файл не отображается.
|
@ -1,7 +0,0 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 616500.300000000046566, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.4 ], [ 616500.599999999976717, 3344999.4 ], [ 616500.599999999976717, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.700000000186265 ] ] ] } }
|
||||
]
|
||||
}
|
Двоичные данные
tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img1/8Band.tif
Двоичные данные
tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img1/8Band.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img1/RGB.tif
Двоичные данные
tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img1/RGB.tif
Двоичный файл не отображается.
|
@ -1,7 +0,0 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:EPSG::3857" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 616500.300000000046566, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.4 ], [ 616500.599999999976717, 3344999.4 ], [ 616500.599999999976717, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.700000000186265 ] ] ] } }
|
||||
]
|
||||
}
|
Двоичные данные
tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img2/8Band.tif
Двоичные данные
tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img2/8Band.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img2/RGB.tif
Двоичные данные
tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img2/RGB.tif
Двоичный файл не отображается.
|
@ -1,7 +0,0 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 616500.300000000046566, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.4 ], [ 616500.599999999976717, 3344999.4 ], [ 616500.599999999976717, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.700000000186265 ] ] ] } }
|
||||
]
|
||||
}
|
Двоичные данные
tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img3/8Band.tif
Двоичные данные
tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img3/8Band.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img3/RGB.tif
Двоичные данные
tests/data/spacenet/sn1_AOI_1_RIO/sn1_AOI_1_RIO_img3/RGB.tif
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn2_AOI_2_Vegas.tar.gz
Двоичные данные
tests/data/spacenet/sn2_AOI_2_Vegas.tar.gz
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn2_AOI_3_Paris.tar.gz
Двоичные данные
tests/data/spacenet/sn2_AOI_3_Paris.tar.gz
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn2_AOI_4_Shanghai.tar.gz
Двоичные данные
tests/data/spacenet/sn2_AOI_4_Shanghai.tar.gz
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn2_AOI_5_Khartoum.tar.gz
Двоичные данные
tests/data/spacenet/sn2_AOI_5_Khartoum.tar.gz
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn3_AOI_2_Vegas.tar.gz
Двоичные данные
tests/data/spacenet/sn3_AOI_2_Vegas.tar.gz
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn3_AOI_3_Paris.tar.gz
Двоичные данные
tests/data/spacenet/sn3_AOI_3_Paris.tar.gz
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn3_AOI_4_Shanghai.tar.gz
Двоичные данные
tests/data/spacenet/sn3_AOI_4_Shanghai.tar.gz
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz
Двоичные данные
tests/data/spacenet/sn3_AOI_5_Khartoum.tar.gz
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn4_AOI_6_Atlanta.tar.gz
Двоичные данные
tests/data/spacenet/sn4_AOI_6_Atlanta.tar.gz
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn5_AOI_7_Moscow.tar.gz
Двоичные данные
tests/data/spacenet/sn5_AOI_7_Moscow.tar.gz
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn5_AOI_8_Mumbai.tar.gz
Двоичные данные
tests/data/spacenet/sn5_AOI_8_Mumbai.tar.gz
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn7_test_source.tar.gz
Двоичные данные
tests/data/spacenet/sn7_test_source.tar.gz
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn7_train_labels.tar.gz
Двоичные данные
tests/data/spacenet/sn7_train_labels.tar.gz
Двоичный файл не отображается.
Двоичные данные
tests/data/spacenet/sn7_train_source.tar.gz
Двоичные данные
tests/data/spacenet/sn7_train_source.tar.gz
Двоичный файл не отображается.
|
@ -1,7 +0,0 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 616500.300000000046566, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.4 ], [ 616500.599999999976717, 3344999.4 ], [ 616500.599999999976717, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.700000000186265 ] ] ] } }
|
||||
]
|
||||
}
|
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -1,7 +0,0 @@
|
|||
{
|
||||
"type": "FeatureCollection",
|
||||
"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
|
||||
"features": [
|
||||
{ "type": "Feature", "properties": { }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 616500.300000000046566, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.4 ], [ 616500.599999999976717, 3344999.4 ], [ 616500.599999999976717, 3344999.700000000186265 ], [ 616500.300000000046566, 3344999.700000000186265 ] ] ] } }
|
||||
]
|
||||
}
|
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -0,0 +1 @@
|
|||
aws.py
|
|
@ -0,0 +1,6 @@
|
|||
REM Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
REM Licensed under the MIT License.
|
||||
|
||||
@ECHO OFF
|
||||
|
||||
python3 tests\datasets\aws.py %*
|
|
@ -0,0 +1,20 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""Basic mock-up of the AWS CLI."""
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers()
|
||||
s3 = subparsers.add_parser('s3')
|
||||
subsubparsers = s3.add_subparsers()
|
||||
cp = subsubparsers.add_parser('cp')
|
||||
cp.add_argument('source')
|
||||
cp.add_argument('destination')
|
||||
args, _ = parser.parse_known_args()
|
||||
shutil.copy(args.source, args.destination)
|
|
@ -31,6 +31,13 @@ def download_url(monkeypatch: MonkeyPatch, request: SubRequest) -> None:
|
|||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aws(monkeypatch: MonkeyPatch) -> Executable:
|
||||
path = os.path.dirname(os.path.realpath(__file__))
|
||||
monkeypatch.setenv('PATH', path, prepend=os.pathsep)
|
||||
return which('aws')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def azcopy(monkeypatch: MonkeyPatch) -> Executable:
|
||||
path = os.path.dirname(os.path.realpath(__file__))
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
@ -10,437 +9,61 @@ import matplotlib.pyplot as plt
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from _pytest.fixtures import SubRequest
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
from torchgeo.datasets import (
|
||||
DatasetNotFoundError,
|
||||
SpaceNet1,
|
||||
SpaceNet2,
|
||||
SpaceNet3,
|
||||
SpaceNet4,
|
||||
SpaceNet5,
|
||||
SpaceNet6,
|
||||
SpaceNet7,
|
||||
)
|
||||
|
||||
TEST_DATA_DIR = 'tests/data/spacenet'
|
||||
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
|
||||
from torchgeo.datasets import DatasetNotFoundError, SpaceNet1
|
||||
from torchgeo.datasets.utils import Executable
|
||||
|
||||
|
||||
class Collection:
|
||||
def __init__(self, collection_id: str) -> None:
|
||||
self.collection_id = collection_id
|
||||
|
||||
def download(self, output_dir: str, **kwargs: str) -> None:
|
||||
glob_path = os.path.join(TEST_DATA_DIR, '*.tar.gz')
|
||||
for tarball in glob.iglob(glob_path):
|
||||
shutil.copy(tarball, output_dir)
|
||||
|
||||
|
||||
class Dataset:
|
||||
def __init__(self, dataset_id: str) -> None:
|
||||
self.dataset_id = dataset_id
|
||||
|
||||
def download(self, output_dir: str, **kwargs: str) -> None:
|
||||
glob_path = os.path.join(TEST_DATA_DIR, 'spacenet*')
|
||||
for directory in glob.iglob(glob_path):
|
||||
dataset_name = os.path.basename(directory)
|
||||
output_dir = os.path.join(output_dir, dataset_name)
|
||||
shutil.copytree(directory, output_dir)
|
||||
|
||||
|
||||
def fetch_collection(collection_id: str, **kwargs: str) -> Collection:
|
||||
return Collection(collection_id)
|
||||
|
||||
|
||||
def fetch_dataset(dataset_id: str, **kwargs: str) -> Dataset:
|
||||
return Dataset(dataset_id)
|
||||
|
||||
|
||||
class TestSpaceNet1:
|
||||
@pytest.fixture(params=['rgb', '8band'])
|
||||
class TestSpaceNet:
|
||||
@pytest.fixture
|
||||
def dataset(
|
||||
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||
self, aws: Executable, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||
) -> SpaceNet1:
|
||||
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection)
|
||||
test_md5 = {'sn1_AOI_1_RIO': '127a523561987110f008e8c9815ce807'}
|
||||
|
||||
# Refer https://github.com/python/mypy/issues/1032
|
||||
monkeypatch.setattr(SpaceNet1, 'collection_md5_dict', test_md5)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return SpaceNet1(
|
||||
root, image=request.param, transforms=transforms, download=True, api_key=''
|
||||
url = os.path.join(
|
||||
'tests', 'data', 'spacenet', '{dataset_id}', 'train', '{tarball}'
|
||||
)
|
||||
monkeypatch.setattr(SpaceNet1, 'url', url)
|
||||
transforms = nn.Identity()
|
||||
return SpaceNet1(tmp_path, transforms=transforms, download=True)
|
||||
|
||||
def test_getitem(self, dataset: SpaceNet1) -> None:
|
||||
x = dataset[0]
|
||||
dataset[1]
|
||||
@pytest.mark.parametrize('index', [0, 1])
|
||||
def test_getitem(self, dataset: SpaceNet1, index: int) -> None:
|
||||
x = dataset[index]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x['image'], torch.Tensor)
|
||||
assert isinstance(x['mask'], torch.Tensor)
|
||||
if dataset.image == 'rgb':
|
||||
assert x['image'].shape[0] == 3
|
||||
else:
|
||||
assert x['image'].shape[0] == 8
|
||||
|
||||
def test_len(self, dataset: SpaceNet1) -> None:
|
||||
assert len(dataset) == 3
|
||||
|
||||
def test_already_extracted(self, dataset: SpaceNet1) -> None:
|
||||
SpaceNet1(root=dataset.root)
|
||||
|
||||
def test_already_downloaded(self, dataset: SpaceNet1) -> None:
|
||||
SpaceNet1(root=dataset.root, download=True)
|
||||
for product in ['3band', '8band', 'geojson']:
|
||||
dir = os.path.join(dataset.root, dataset.dataset_id, dataset.split, product)
|
||||
shutil.rmtree(dir)
|
||||
SpaceNet1(root=dataset.root)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SpaceNet1(tmp_path)
|
||||
|
||||
def test_plot(self, dataset: SpaceNet1) -> None:
|
||||
x = dataset[0].copy()
|
||||
x['prediction'] = x['mask']
|
||||
dataset.plot(x, suptitle='Test')
|
||||
plt.close()
|
||||
dataset.plot(x, show_titles=False)
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestSpaceNet2:
|
||||
@pytest.fixture(params=['PAN', 'MS', 'PS-MS', 'PS-RGB'])
|
||||
def dataset(
|
||||
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||
) -> SpaceNet2:
|
||||
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection)
|
||||
test_md5 = {
|
||||
'sn2_AOI_2_Vegas': '131048686ba21a45853c05f227f40b7f',
|
||||
'sn2_AOI_3_Paris': '62242fd198ee32b59f0178cf656e1513',
|
||||
'sn2_AOI_4_Shanghai': '563b0817ecedd8ff3b3e4cb2991bf3fb',
|
||||
'sn2_AOI_5_Khartoum': 'e4185a2e9a12cf7b3d0cd1db6b3e0f06',
|
||||
}
|
||||
|
||||
monkeypatch.setattr(SpaceNet2, 'collection_md5_dict', test_md5)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return SpaceNet2(
|
||||
root,
|
||||
image=request.param,
|
||||
collections=['sn2_AOI_2_Vegas', 'sn2_AOI_5_Khartoum'],
|
||||
transforms=transforms,
|
||||
download=True,
|
||||
api_key='',
|
||||
)
|
||||
|
||||
def test_getitem(self, dataset: SpaceNet2) -> None:
|
||||
x = dataset[0]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x['image'], torch.Tensor)
|
||||
assert isinstance(x['mask'], torch.Tensor)
|
||||
if dataset.image == 'PS-RGB':
|
||||
assert x['image'].shape[0] == 3
|
||||
elif dataset.image in ['MS', 'PS-MS']:
|
||||
assert x['image'].shape[0] == 8
|
||||
else:
|
||||
assert x['image'].shape[0] == 1
|
||||
|
||||
def test_len(self, dataset: SpaceNet2) -> None:
|
||||
assert len(dataset) == 4
|
||||
|
||||
def test_already_downloaded(self, dataset: SpaceNet2) -> None:
|
||||
SpaceNet2(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SpaceNet2(tmp_path)
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet2) -> None:
|
||||
dataset.collection_md5_dict['sn2_AOI_2_Vegas'] = 'randommd5hash123'
|
||||
with pytest.raises(RuntimeError, match='Collection sn2_AOI_2_Vegas corrupted'):
|
||||
SpaceNet2(root=dataset.root, download=True, checksum=True)
|
||||
|
||||
def test_plot(self, dataset: SpaceNet2) -> None:
|
||||
x = dataset[0].copy()
|
||||
dataset.plot(x, show_titles=False)
|
||||
plt.close()
|
||||
x['prediction'] = x['mask']
|
||||
dataset.plot(x, suptitle='Test')
|
||||
plt.close()
|
||||
dataset.plot(x, show_titles=False)
|
||||
plt.close()
|
||||
|
||||
def test_image_id(self, monkeypatch: MonkeyPatch, dataset: SpaceNet1) -> None:
|
||||
file_regex = r'global_monthly_(\d+.*\d+)'
|
||||
monkeypatch.setattr(dataset, 'file_regex', file_regex)
|
||||
dataset._image_id('global_monthly_2018_01_mosaic_L15-0331E-1257N_1327_3160.tif')
|
||||
|
||||
class TestSpaceNet3:
|
||||
@pytest.fixture(params=zip(['PAN', 'MS'], [False, True]))
|
||||
def dataset(
|
||||
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||
) -> SpaceNet3:
|
||||
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection)
|
||||
test_md5 = {
|
||||
'sn3_AOI_3_Paris': '93452c68da11dd6b57dc83dba43c2c9d',
|
||||
'sn3_AOI_5_Khartoum': '7c9d96810198bf101cbaf54f7a5e8b3b',
|
||||
}
|
||||
|
||||
monkeypatch.setattr(SpaceNet3, 'collection_md5_dict', test_md5)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return SpaceNet3(
|
||||
root,
|
||||
image=request.param[0],
|
||||
speed_mask=request.param[1],
|
||||
collections=['sn3_AOI_3_Paris', 'sn3_AOI_5_Khartoum'],
|
||||
transforms=transforms,
|
||||
download=True,
|
||||
api_key='',
|
||||
)
|
||||
|
||||
def test_getitem(self, dataset: SpaceNet3) -> None:
|
||||
# Iterate over all elements to maximize coverage
|
||||
samples = [dataset[i] for i in range(len(dataset))]
|
||||
x = samples[0]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x['image'], torch.Tensor)
|
||||
assert isinstance(x['mask'], torch.Tensor)
|
||||
if dataset.image == 'MS':
|
||||
assert x['image'].shape[0] == 8
|
||||
else:
|
||||
assert x['image'].shape[0] == 1
|
||||
|
||||
def test_len(self, dataset: SpaceNet3) -> None:
|
||||
assert len(dataset) == 4
|
||||
|
||||
def test_already_downloaded(self, dataset: SpaceNet3) -> None:
|
||||
SpaceNet3(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SpaceNet3(tmp_path)
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet3) -> None:
|
||||
dataset.collection_md5_dict['sn3_AOI_5_Khartoum'] = 'randommd5hash123'
|
||||
with pytest.raises(
|
||||
RuntimeError, match='Collection sn3_AOI_5_Khartoum corrupted'
|
||||
):
|
||||
SpaceNet3(root=dataset.root, download=True, checksum=True)
|
||||
|
||||
def test_plot(self, dataset: SpaceNet3) -> None:
|
||||
x = dataset[0].copy()
|
||||
x['prediction'] = x['mask']
|
||||
dataset.plot(x, suptitle='Test')
|
||||
plt.close()
|
||||
dataset.plot(x, show_titles=False)
|
||||
plt.close()
|
||||
dataset.plot({'image': x['image']})
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestSpaceNet4:
|
||||
@pytest.fixture(params=['PAN', 'MS', 'PS-RGBNIR'])
|
||||
def dataset(
|
||||
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||
) -> SpaceNet4:
|
||||
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection)
|
||||
test_md5 = {'sn4_AOI_6_Atlanta': '097a76a2319b7ba34dac1722862fc93b'}
|
||||
|
||||
test_angles = ['nadir', 'off-nadir', 'very-off-nadir']
|
||||
|
||||
monkeypatch.setattr(SpaceNet4, 'collection_md5_dict', test_md5)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return SpaceNet4(
|
||||
root,
|
||||
image=request.param,
|
||||
angles=test_angles,
|
||||
transforms=transforms,
|
||||
download=True,
|
||||
api_key='',
|
||||
)
|
||||
|
||||
def test_getitem(self, dataset: SpaceNet4) -> None:
|
||||
# Get image-label pair with empty label to
|
||||
# ensure coverage
|
||||
x = dataset[2]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x['image'], torch.Tensor)
|
||||
assert isinstance(x['mask'], torch.Tensor)
|
||||
if dataset.image == 'PS-RGBNIR':
|
||||
assert x['image'].shape[0] == 4
|
||||
elif dataset.image == 'MS':
|
||||
assert x['image'].shape[0] == 8
|
||||
else:
|
||||
assert x['image'].shape[0] == 1
|
||||
|
||||
def test_len(self, dataset: SpaceNet4) -> None:
|
||||
assert len(dataset) == 4
|
||||
|
||||
def test_already_downloaded(self, dataset: SpaceNet4) -> None:
|
||||
SpaceNet4(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SpaceNet4(tmp_path)
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet4) -> None:
|
||||
dataset.collection_md5_dict['sn4_AOI_6_Atlanta'] = 'randommd5hash123'
|
||||
with pytest.raises(
|
||||
RuntimeError, match='Collection sn4_AOI_6_Atlanta corrupted'
|
||||
):
|
||||
SpaceNet4(root=dataset.root, download=True, checksum=True)
|
||||
|
||||
def test_plot(self, dataset: SpaceNet4) -> None:
|
||||
x = dataset[0].copy()
|
||||
x['prediction'] = x['mask']
|
||||
dataset.plot(x, suptitle='Test')
|
||||
plt.close()
|
||||
dataset.plot(x, show_titles=False)
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestSpaceNet5:
|
||||
@pytest.fixture(params=zip(['PAN', 'MS'], [False, True]))
|
||||
def dataset(
|
||||
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||
) -> SpaceNet5:
|
||||
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection)
|
||||
test_md5 = {
|
||||
'sn5_AOI_7_Moscow': '5c511dd31eea739cc1f81ef5962f3d56',
|
||||
'sn5_AOI_8_Mumbai': 'e00452b87bbe87feaef65f373be3978e',
|
||||
}
|
||||
|
||||
monkeypatch.setattr(SpaceNet5, 'collection_md5_dict', test_md5)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return SpaceNet5(
|
||||
root,
|
||||
image=request.param[0],
|
||||
speed_mask=request.param[1],
|
||||
collections=['sn5_AOI_7_Moscow', 'sn5_AOI_8_Mumbai'],
|
||||
transforms=transforms,
|
||||
download=True,
|
||||
api_key='',
|
||||
)
|
||||
|
||||
def test_getitem(self, dataset: SpaceNet5) -> None:
|
||||
# Iterate over all elements to maximize coverage
|
||||
samples = [dataset[i] for i in range(len(dataset))]
|
||||
x = samples[0]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x['image'], torch.Tensor)
|
||||
assert isinstance(x['mask'], torch.Tensor)
|
||||
if dataset.image == 'MS':
|
||||
assert x['image'].shape[0] == 8
|
||||
else:
|
||||
assert x['image'].shape[0] == 1
|
||||
|
||||
def test_len(self, dataset: SpaceNet5) -> None:
|
||||
assert len(dataset) == 5
|
||||
|
||||
def test_already_downloaded(self, dataset: SpaceNet5) -> None:
|
||||
SpaceNet5(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SpaceNet5(tmp_path)
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet5) -> None:
|
||||
dataset.collection_md5_dict['sn5_AOI_8_Mumbai'] = 'randommd5hash123'
|
||||
with pytest.raises(RuntimeError, match='Collection sn5_AOI_8_Mumbai corrupted'):
|
||||
SpaceNet5(root=dataset.root, download=True, checksum=True)
|
||||
|
||||
def test_plot(self, dataset: SpaceNet5) -> None:
|
||||
x = dataset[0].copy()
|
||||
x['prediction'] = x['mask']
|
||||
dataset.plot(x, suptitle='Test')
|
||||
plt.close()
|
||||
dataset.plot(x, show_titles=False)
|
||||
plt.close()
|
||||
dataset.plot({'image': x['image']})
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestSpaceNet6:
|
||||
@pytest.fixture(params=['PAN', 'RGBNIR', 'PS-RGB', 'PS-RGBNIR', 'SAR-Intensity'])
|
||||
def dataset(
|
||||
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||
) -> SpaceNet6:
|
||||
monkeypatch.setattr(radiant_mlhub.Dataset, 'fetch', fetch_dataset)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return SpaceNet6(
|
||||
root, image=request.param, transforms=transforms, download=True, api_key=''
|
||||
)
|
||||
|
||||
def test_getitem(self, dataset: SpaceNet6) -> None:
|
||||
x = dataset[0]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x['image'], torch.Tensor)
|
||||
assert isinstance(x['mask'], torch.Tensor)
|
||||
if dataset.image == 'PS-RGB':
|
||||
assert x['image'].shape[0] == 3
|
||||
elif dataset.image in ['RGBNIR', 'PS-RGBNIR']:
|
||||
assert x['image'].shape[0] == 4
|
||||
else:
|
||||
assert x['image'].shape[0] == 1
|
||||
|
||||
def test_len(self, dataset: SpaceNet6) -> None:
|
||||
assert len(dataset) == 2
|
||||
|
||||
def test_already_downloaded(self, dataset: SpaceNet6) -> None:
|
||||
SpaceNet6(root=dataset.root, download=True)
|
||||
|
||||
def test_plot(self, dataset: SpaceNet6) -> None:
|
||||
x = dataset[0].copy()
|
||||
x['prediction'] = x['mask']
|
||||
dataset.plot(x, suptitle='Test')
|
||||
plt.close()
|
||||
dataset.plot(x, show_titles=False)
|
||||
plt.close()
|
||||
|
||||
|
||||
class TestSpaceNet7:
|
||||
@pytest.fixture(params=['train', 'test'])
|
||||
def dataset(
|
||||
self, request: SubRequest, monkeypatch: MonkeyPatch, tmp_path: Path
|
||||
) -> SpaceNet7:
|
||||
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection)
|
||||
test_md5 = {
|
||||
'sn7_train_source': '197bfa8842a40b09b6837b824a6370e0',
|
||||
'sn7_train_labels': '625ad8a989a5105bc766a53e53df4d0e',
|
||||
'sn7_test_source': '461f59eb21bb4f416c867f5037dfceeb',
|
||||
}
|
||||
|
||||
monkeypatch.setattr(SpaceNet7, 'collection_md5_dict', test_md5)
|
||||
root = tmp_path
|
||||
transforms = nn.Identity()
|
||||
return SpaceNet7(
|
||||
root, split=request.param, transforms=transforms, download=True, api_key=''
|
||||
)
|
||||
|
||||
def test_getitem(self, dataset: SpaceNet7) -> None:
|
||||
x = dataset[0]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x['image'], torch.Tensor)
|
||||
if dataset.split == 'train':
|
||||
assert isinstance(x['mask'], torch.Tensor)
|
||||
|
||||
def test_len(self, dataset: SpaceNet7) -> None:
|
||||
if dataset.split == 'train':
|
||||
assert len(dataset) == 2
|
||||
else:
|
||||
assert len(dataset) == 1
|
||||
|
||||
def test_already_downloaded(self, dataset: SpaceNet4) -> None:
|
||||
SpaceNet7(root=dataset.root, download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
|
||||
SpaceNet7(tmp_path)
|
||||
|
||||
def test_collection_checksum(self, dataset: SpaceNet4) -> None:
|
||||
dataset.collection_md5_dict['sn7_train_source'] = 'randommd5hash123'
|
||||
with pytest.raises(RuntimeError, match='Collection sn7_train_source corrupted'):
|
||||
SpaceNet7(root=dataset.root, download=True, checksum=True)
|
||||
|
||||
def test_plot(self, dataset: SpaceNet7) -> None:
|
||||
x = dataset[0].copy()
|
||||
if dataset.split == 'train':
|
||||
x['prediction'] = x['mask']
|
||||
dataset.plot(x, suptitle='Test')
|
||||
plt.close()
|
||||
dataset.plot(x, show_titles=False)
|
||||
plt.close()
|
||||
def test_list_files(self, monkeypatch: MonkeyPatch, dataset: SpaceNet1) -> None:
|
||||
directory_glob = os.path.join('**', 'AOI_{aoi}_*', '{product}')
|
||||
monkeypatch.setattr(dataset, 'directory_glob', directory_glob)
|
||||
dataset._list_files(aoi=1)
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
@ -15,7 +13,6 @@ from typing import Any
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from pytest import MonkeyPatch
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from torchgeo.datasets import BoundingBox, DependencyNotFoundError
|
||||
|
@ -24,8 +21,6 @@ from torchgeo.datasets.utils import (
|
|||
array_to_tensor,
|
||||
concat_samples,
|
||||
disambiguate_timestamp,
|
||||
download_radiant_mlhub_collection,
|
||||
download_radiant_mlhub_dataset,
|
||||
lazy_import,
|
||||
merge_samples,
|
||||
percentile_normalization,
|
||||
|
@ -36,48 +31,6 @@ from torchgeo.datasets.utils import (
|
|||
)
|
||||
|
||||
|
||||
class MLHubDataset:
|
||||
def download(self, output_dir: str, **kwargs: str) -> None:
|
||||
glob_path = os.path.join(
|
||||
'tests', 'data', 'ref_african_crops_kenya_02', '*.tar.gz'
|
||||
)
|
||||
for tarball in glob.iglob(glob_path):
|
||||
shutil.copy(tarball, output_dir)
|
||||
|
||||
|
||||
class Collection:
|
||||
def download(self, output_dir: str, **kwargs: str) -> None:
|
||||
glob_path = os.path.join(
|
||||
'tests', 'data', 'ref_african_crops_kenya_02', '*.tar.gz'
|
||||
)
|
||||
for tarball in glob.iglob(glob_path):
|
||||
shutil.copy(tarball, output_dir)
|
||||
|
||||
|
||||
def fetch_dataset(dataset_id: str, **kwargs: str) -> MLHubDataset:
|
||||
return MLHubDataset()
|
||||
|
||||
|
||||
def fetch_collection(collection_id: str, **kwargs: str) -> Collection:
|
||||
return Collection()
|
||||
|
||||
|
||||
def test_download_radiant_mlhub_dataset(
|
||||
tmp_path: Path, monkeypatch: MonkeyPatch
|
||||
) -> None:
|
||||
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
|
||||
monkeypatch.setattr(radiant_mlhub.Dataset, 'fetch', fetch_dataset)
|
||||
download_radiant_mlhub_dataset('', tmp_path)
|
||||
|
||||
|
||||
def test_download_radiant_mlhub_collection(
|
||||
tmp_path: Path, monkeypatch: MonkeyPatch
|
||||
) -> None:
|
||||
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
|
||||
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch_collection)
|
||||
download_radiant_mlhub_collection('', tmp_path)
|
||||
|
||||
|
||||
class TestBoundingBox:
|
||||
def test_repr_str(self) -> None:
|
||||
bbox = BoundingBox(0, 1, 2.0, 3.0, -5, -4)
|
||||
|
|
|
@ -111,6 +111,7 @@ from .spacenet import (
|
|||
SpaceNet5,
|
||||
SpaceNet6,
|
||||
SpaceNet7,
|
||||
SpaceNet8,
|
||||
)
|
||||
from .splits import (
|
||||
random_bbox_assignment,
|
||||
|
@ -244,6 +245,7 @@ __all__ = (
|
|||
'SpaceNet5',
|
||||
'SpaceNet6',
|
||||
'SpaceNet7',
|
||||
'SpaceNet8',
|
||||
'SSL4EO',
|
||||
'SSL4EOLBenchmark',
|
||||
'SSL4EOL',
|
||||
|
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -45,46 +45,6 @@ __all__ = (
|
|||
Path: TypeAlias = str | pathlib.Path
|
||||
|
||||
|
||||
def download_radiant_mlhub_dataset(
|
||||
dataset_id: str, download_root: Path, api_key: str | None = None
|
||||
) -> None:
|
||||
"""Download a dataset from Radiant Earth.
|
||||
|
||||
Args:
|
||||
dataset_id: the ID of the dataset to fetch
|
||||
download_root: directory to download to
|
||||
api_key: the API key to use for all requests from the session. Can also be
|
||||
passed in via the ``MLHUB_API_KEY`` environment variable, or configured in
|
||||
``~/.mlhub/profiles``.
|
||||
|
||||
Raises:
|
||||
DependencyNotFoundError: If radiant_mlhub is not installed.
|
||||
"""
|
||||
radiant_mlhub = lazy_import('radiant_mlhub')
|
||||
dataset = radiant_mlhub.Dataset.fetch(dataset_id, api_key=api_key)
|
||||
dataset.download(output_dir=download_root, api_key=api_key)
|
||||
|
||||
|
||||
def download_radiant_mlhub_collection(
|
||||
collection_id: str, download_root: Path, api_key: str | None = None
|
||||
) -> None:
|
||||
"""Download a collection from Radiant Earth.
|
||||
|
||||
Args:
|
||||
collection_id: the ID of the collection to fetch
|
||||
download_root: directory to download to
|
||||
api_key: the API key to use for all requests from the session. Can also be
|
||||
passed in via the ``MLHUB_API_KEY`` environment variable, or configured in
|
||||
``~/.mlhub/profiles``.
|
||||
|
||||
Raises:
|
||||
DependencyNotFoundError: If radiant_mlhub is not installed.
|
||||
"""
|
||||
radiant_mlhub = lazy_import('radiant_mlhub')
|
||||
collection = radiant_mlhub.Collection.fetch(collection_id, api_key=api_key)
|
||||
collection.download(output_dir=download_root, api_key=api_key)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BoundingBox:
|
||||
"""Data class for indexing spatiotemporal data."""
|
||||
|
|
Загрузка…
Ссылка в новой задаче