зеркало из https://github.com/microsoft/torchgeo.git
SkyScript: add new dataset (#2253)
* SkyScript: add new dataset * Remove print statements * Fix bug * 100% coverage * text -> caption * Simpler tests * Reformat
This commit is contained in:
Родитель
7ad9df8f32
Коммит
451b5a5919
|
@ -191,7 +191,7 @@ Non-geospatial Datasets
|
|||
|
||||
:class:`NonGeoDataset` is designed for datasets that lack geospatial information. These datasets can still be combined using :class:`ConcatDataset <torch.utils.data.ConcatDataset>`.
|
||||
|
||||
.. csv-table:: C = classification, R = regression, S = semantic segmentation, I = instance segmentation, T = time series, CD = change detection, OD = object detection
|
||||
.. csv-table:: C = classification, R = regression, S = semantic segmentation, I = instance segmentation, T = time series, CD = change detection, OD = object detection, IC = image captioning
|
||||
:widths: 15 7 15 20 12 11 12 15 13
|
||||
:header-rows: 1
|
||||
:align: center
|
||||
|
@ -397,6 +397,11 @@ SKIPP'D
|
|||
|
||||
.. autoclass:: SKIPPD
|
||||
|
||||
SkyScript
|
||||
^^^^^^^^^
|
||||
|
||||
.. autoclass:: SkyScript
|
||||
|
||||
So2Sat
|
||||
^^^^^^
|
||||
|
||||
|
|
|
@ -38,6 +38,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
|
|||
`SeasoNet`_,S,Sentinel-2,"CC-BY-4.0","1,759,830",33,120x120,10,MSI
|
||||
`SEN12MS`_,S,"Sentinel-1/2, MODIS","CC-BY-4.0","180,662",33,256x256,10,"SAR, MSI"
|
||||
`SKIPP'D`_,R,"Fish-eye","CC-BY-4.0","363,375",-,64x64,-,RGB
|
||||
`SkyScript`_,IC,"NAIP, orthophotos, Planet SkySat, Sentinel-2, Landsat 8--9",MIT,5.2M,-,100--1000,0.1--30,RGB
|
||||
`So2Sat`_,C,Sentinel-1/2,"CC-BY-4.0","400,673",17,32x32,10,"SAR, MSI"
|
||||
`SpaceNet`_,I,WorldView-2/3 Planet Lab Dove,"CC-BY-SA-4.0","1,889--28,728",2,102--900,0.5--4,MSI
|
||||
`SSL4EO`_-L,T,Landsat,"CC0-1.0",1M,-,264x264,30,MSI
|
||||
|
|
|
|
@ -0,0 +1,3 @@
|
|||
filepath,title,title_multi_objects,similarity_CLIP_openai
|
||||
images6/w779523169_CH_18.jpg,"a satellite image of a beautiful house I will never be able to afford","a satellite image of a beautiful house, surrounded by a yard",0.1
|
||||
images7/w602363451_US_21.jpg,"a satellite image of the last mall in the world","a satellite image of a mall; surrounded by a parking lot",0.2
|
|
|
@ -0,0 +1,3 @@
|
|||
filepath,title,title_multi_objects,similarity_CLIP_openai
|
||||
images2/w779523169_CH_18.jpg,"a satellite image of a beautiful house I will never be able to afford","a satellite image of a beautiful house, surrounded by a yard",0.1
|
||||
images3/w602363451_US_21.jpg,"a satellite image of the last mall in the world","a satellite image of a mall; surrounded by a parking lot",0.2
|
|
|
@ -0,0 +1,3 @@
|
|||
filepath,title,title_multi_objects,similarity_CLIP_openai
|
||||
images4/w779523169_CH_18.jpg,"a satellite image of a beautiful house I will never be able to afford","a satellite image of a beautiful house, surrounded by a yard",0.1
|
||||
images5/w602363451_US_21.jpg,"a satellite image of the last mall in the world","a satellite image of a mall; surrounded by a parking lot",0.2
|
|
|
@ -0,0 +1,28 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
|
||||
# Exclusive upper bound (in pixels) for the width and height of each fake image.
SIZE = 32

# Fixed seed so that regenerating the fixtures draws the same image sizes.
random.seed(0)

# glob yields files in arbitrary (filesystem-dependent) order; sort so the seeded
# random numbers are always consumed by the same files in the same order.
for csv in sorted(glob.iglob('*.csv')):
    captions = pd.read_csv(csv)
    for jpg in captions['filepath']:
        os.makedirs(os.path.dirname(jpg), exist_ok=True)
        # Start at 1: a zero-length side produces an empty image, which cannot
        # be written as a JPEG.
        width = random.randrange(1, SIZE)
        height = random.randrange(1, SIZE)
        img = Image.new('RGB', (width, height))
        img.save(jpg)

# Create one zip archive per image directory (images2.zip ... images7.zip), each
# containing the directory itself as its top-level entry.
for directory in [f'images{i}' for i in range(2, 8)]:
    shutil.make_archive(directory, 'zip', '.', directory)
|
Двоичный файл не отображается.
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 635 B |
Двоичный файл не отображается.
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 635 B |
Двоичный файл не отображается.
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 643 B |
Двоичный файл не отображается.
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 643 B |
Двоичный файл не отображается.
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 643 B |
Двоичный файл не отображается.
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 631 B |
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch.nn as nn
|
||||
from matplotlib import pyplot as plt
|
||||
from pytest import MonkeyPatch
|
||||
from torch import Tensor
|
||||
|
||||
from torchgeo.datasets import DatasetNotFoundError, SkyScript
|
||||
|
||||
|
||||
class TestSkyScript:
    """Tests for the SkyScript dataset, run against the local fixture files."""

    @pytest.fixture
    def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> SkyScript:
        """Return a SkyScript dataset rooted at *tmp_path* with download enabled."""
        # Replace the remote URL template with the local fixture directory so that
        # "downloading" never touches the network.
        local_template = os.path.join('tests', 'data', 'skyscript', '{}')
        monkeypatch.setattr(SkyScript, 'url', local_template)
        return SkyScript(tmp_path, transforms=nn.Identity(), download=True)

    def test_getitem(self, dataset: SkyScript) -> None:
        """A sample is a dict holding a Tensor image and a str caption."""
        sample = dataset[0]
        assert isinstance(sample, dict)
        assert isinstance(sample['image'], Tensor)
        assert isinstance(sample['caption'], str)

    def test_len(self, dataset: SkyScript) -> None:
        """The fixture caption file for the default split has two rows."""
        assert len(dataset) == 2

    def test_already_downloaded(self, dataset: SkyScript) -> None:
        """The dataset can be re-created without download after losing a directory."""
        # Remove one extracted directory; construction must still succeed
        # (presumably by re-extracting images2.zip left in the root — verify
        # against download_and_extract_archive).
        extracted_dir = os.path.join(dataset.root, 'images2')
        shutil.rmtree(extracted_dir)
        SkyScript(dataset.root)

    def test_not_downloaded(self, tmp_path: Path) -> None:
        """An empty root without download raises DatasetNotFoundError."""
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            SkyScript(tmp_path)

    def test_plot(self, dataset: SkyScript) -> None:
        """Plotting works for a sample that also carries a prediction."""
        sample = dataset[0]
        # Reuse the caption as a fake prediction to exercise that branch of plot().
        sample['prediction'] = sample['caption']
        dataset.plot(sample, suptitle='Test')
        plt.close()
|
|
@ -100,6 +100,7 @@ from .seco import SeasonalContrastS2
|
|||
from .sen12ms import SEN12MS
|
||||
from .sentinel import Sentinel, Sentinel1, Sentinel2
|
||||
from .skippd import SKIPPD
|
||||
from .skyscript import SkyScript
|
||||
from .so2sat import So2Sat
|
||||
from .south_africa_crop_type import SouthAfricaCropType
|
||||
from .south_america_soybean import SouthAmericaSoybean
|
||||
|
@ -238,6 +239,7 @@ __all__ = (
|
|||
'SeasoNet',
|
||||
'SEN12MS',
|
||||
'SKIPPD',
|
||||
'SkyScript',
|
||||
'So2Sat',
|
||||
'SpaceNet',
|
||||
'SpaceNet1',
|
||||
|
|
|
@ -0,0 +1,187 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""SkyScript dataset."""
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
from .errors import DatasetNotFoundError
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import Path, download_and_extract_archive, download_url, extract_archive
|
||||
|
||||
|
||||
class SkyScript(NonGeoDataset):
    """SkyScript dataset.

    `SkyScript <https://github.com/wangzhecheng/SkyScript>`__ is a large and
    semantically diverse image-text dataset for remote sensing images. It contains
    5.2 million remote sensing image-text pairs in total, covering more than 29K
    distinct semantic tags.

    If you use this dataset in your research, please cite the following paper:

    * https://arxiv.org/abs/2312.12856

    .. versionadded:: 0.6
    """

    #: URL template; ``{}`` is replaced by the name of the file to download
    url = 'https://opendatasharing.s3.us-west-2.amazonaws.com/SkyScript/{}'

    #: Image directories (``images2`` ... ``images7``); each one is distributed
    #: as ``<directory>.zip``
    image_dirs = tuple(f'images{i}' for i in range(2, 8))
    #: MD5 checksums of the zip files, in the same order as ``image_dirs``
    image_md5s = (
        'fbfb5f7aa1731f4106fc3ffbd608100a',
        'ad4fd9fdb9622d1ea360210cb222f2bd',
        'aeeb41e830304c74b14b5ffc1fc8e8c3',
        '02ee7e0e59f9ac1c87b678a155e1f1df',
        '350475f1e7fb996152fa16db891b4142',
        '5e2fbf3e9262b36e30b458ec9a1df625',
    )

    #: Can be modified in subclasses to change train/val/test split
    caption_files: ClassVar[dict[str, str]] = {
        'train': 'SkyScript_train_top30pct_filtered_by_CLIP_openai.csv',
        'val': 'SkyScript_val_5K_filtered_by_CLIP_openai.csv',
        'test': 'SkyScript_test_30K_filtered_by_CLIP_openai.csv',
    }
    #: MD5 checksums of the caption files, keyed by split like ``caption_files``
    caption_md5s: ClassVar[dict[str, str]] = {
        'train': '05b362e43a852667b5374c9a5ae53f8e',
        'val': 'c8d278fd29b754361989d5e7a6608f69',
        'test': '0135d9b49ce6751360912a4353e809dc',
    }

    def __init__(
        self,
        root: Path = 'data',
        split: str = 'train',
        transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new SkyScript instance.

        Args:
            root: Root directory where dataset can be found.
            split: One of 'train', 'val', 'test'.
            transforms: A function/transform that takes the sample dict (with
                'image' and 'caption' keys) and returns a transformed version.
            download: If True, download dataset and store it in the root directory.
            checksum: If True, check the MD5 of the downloaded files (may be slow).

        Raises:
            AssertionError: If *split* is invalid.
            DatasetNotFoundError: If dataset is not found and *download* is False.
        """
        assert split in self.caption_files

        self.root = root
        self.split = split
        self.transforms = transforms
        self.download = download
        self.checksum = checksum

        # Must run before reading the captions: it extracts/downloads whatever
        # is missing, or raises if the dataset cannot be found.
        self._verify()

        self.captions = pd.read_csv(os.path.join(self.root, self.caption_files[split]))

    def __len__(self) -> int:
        """Return the number of images in the dataset.

        Returns:
            Length of the dataset.
        """
        return len(self.captions)

    def __getitem__(self, index: int) -> dict[str, Any]:
        """Return the sample at an index within the dataset.

        Args:
            index: Index to return.

        Returns:
            A dict containing image and caption at index.
        """
        # The first two columns of the caption file are the image path (relative
        # to the root directory) and its caption.
        filepath, title = self.captions.iloc[index, :2]

        with Image.open(os.path.join(self.root, filepath)) as img:
            # Force three channels: a grayscale, palette or RGBA file would
            # otherwise not match the 'h w c' pattern below.
            array = np.array(img.convert('RGB'), dtype=np.float32)

        array = rearrange(array, 'h w c -> c h w')
        image = torch.from_numpy(array)

        sample = {'image': image, 'caption': title}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        For every image directory: accept an already extracted directory, else
        extract an already downloaded zip file, else download and extract it if
        *download* is True. Afterwards make sure the caption file of the selected
        split is available.

        Raises:
            DatasetNotFoundError: If an image directory or the caption file is
                missing and *download* is False.
        """
        md5: str | None
        for directory, md5 in zip(self.image_dirs, self.image_md5s):
            # Check if the extracted files already exist
            if os.path.isdir(os.path.join(self.root, directory)):
                continue

            # Check if the zip files have already been downloaded
            zip_path = os.path.join(self.root, f'{directory}.zip')
            if os.path.isfile(zip_path):
                extract_archive(zip_path)
                continue

            # Check if the user requested to download the dataset
            if not self.download:
                raise DatasetNotFoundError(self)

            # Download the dataset
            url = self.url.format(f'{directory}.zip')
            md5 = md5 if self.checksum else None
            download_and_extract_archive(url, self.root, md5=md5)

        caption_file = self.caption_files[self.split]
        if self.download:
            # Download the caption file
            url = self.url.format(caption_file)
            md5 = self.caption_md5s[self.split] if self.checksum else None
            download_url(url, self.root, md5=md5)
        elif not os.path.isfile(os.path.join(self.root, caption_file)):
            # Without this check a missing caption file would only surface later
            # as an error from pd.read_csv in __init__.
            raise DatasetNotFoundError(self)

    def plot(
        self,
        sample: dict[str, Any],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        fig, ax = plt.subplots()

        # Channels-last for imshow; divide by 255 to bring the float pixel values
        # into the [0, 1] range imshow expects for float RGB data.
        image = rearrange(sample['image'], 'c h w -> h w c') / 255
        ax.imshow(image)
        ax.axis('off')

        if show_titles:
            title = sample['caption']
            if 'prediction' in sample:
                # Show the predicted caption on a second line under the true one
                title += '\n' + sample['prediction']
            ax.set_title(title)

        if suptitle is not None:
            # Set it on this figure explicitly instead of on pyplot's notion of
            # the "current" figure.
            fig.suptitle(suptitle)

        return fig
|
Загрузка…
Ссылка в новой задаче