Tropical Cyclone: radiant mlhub -> source cooperative (#2068)

* Tropical Cyclone: radiant mlhub -> source cooperative

* Update azcopy unit tests
This commit is contained in:
Adam J. Stewart 2024-07-10 10:35:11 +02:00 коммит произвёл GitHub
Родитель 11ec656c2d
Коммит ab258bfc96
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
51 изменённых файлов: 127 добавлений и 269 удалений

32
tests/data/cyclone/data.py Executable file
Просмотреть файл

@ -0,0 +1,32 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import numpy as np
import pandas as pd
from PIL import Image
DTYPE = np.uint8
SIZE = 2
np.random.seed(0)
for split in ['train', 'test']:
os.makedirs(split, exist_ok=True)
filename = split
if split == 'train':
filename = 'training'
features = pd.read_csv(f'{filename}_set_features.csv')
for image_id, _, _, ocean in features.values:
size = (SIZE, SIZE)
if ocean % 2 == 0:
size = (SIZE * 2, SIZE * 2, 3)
arr = np.random.randint(np.iinfo(DTYPE).max, size=size, dtype=DTYPE)
img = Image.fromarray(arr)
img.save(os.path.join(split, f'{image_id}.jpg'))

Двоичный файл не отображается.

Просмотреть файл

@ -1,24 +0,0 @@
{
"links": [
{
"href": "nasa_tropical_storm_competition_test_labels_a_000/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_labels_b_001/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_labels_c_002/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_labels_d_003/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_labels_e_004/stac.json",
"rel": "item"
}
]
}

Просмотреть файл

@ -1 +0,0 @@
{"wind_speed": "34"}

Просмотреть файл

@ -1 +0,0 @@
{"wind_speed": "34"}

Просмотреть файл

@ -1 +0,0 @@
{"wind_speed": "34"}

Просмотреть файл

@ -1 +0,0 @@
{"wind_speed": "34"}

Просмотреть файл

@ -1 +0,0 @@
{"wind_speed": "34"}

Двоичный файл не отображается.

Просмотреть файл

@ -1,24 +0,0 @@
{
"links": [
{
"href": "nasa_tropical_storm_competition_test_source_a_000/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_source_b_001/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_source_c_002/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_source_d_003/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_test_source_e_004/stac.json",
"rel": "item"
}
]
}

Просмотреть файл

@ -1 +0,0 @@
{"storm_id": "a", "relative_time": "0", "ocean": "2"}

Просмотреть файл

@ -1 +0,0 @@
{"storm_id": "b", "relative_time": "0", "ocean": "2"}

Просмотреть файл

@ -1 +0,0 @@
{"storm_id": "c", "relative_time": "0", "ocean": "2"}

Просмотреть файл

@ -1 +0,0 @@
{"storm_id": "d", "relative_time": "0", "ocean": "2"}

Просмотреть файл

@ -1 +0,0 @@
{"storm_id": "e", "relative_time": "0", "ocean": "2"}

Двоичный файл не отображается.

Просмотреть файл

@ -1,24 +0,0 @@
{
"links": [
{
"href": "nasa_tropical_storm_competition_train_labels_a_000/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_labels_b_001/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_labels_c_002/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_labels_d_003/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_labels_e_004/stac.json",
"rel": "item"
}
]
}

Двоичный файл не отображается.

Просмотреть файл

@ -1,24 +0,0 @@
{
"links": [
{
"href": "nasa_tropical_storm_competition_train_source_a_000/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_source_b_001/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_source_c_002/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_source_d_003/stac.json",
"rel": "item"
},
{
"href": "nasa_tropical_storm_competition_train_source_e_004/stac.json",
"rel": "item"
}
]
}

Просмотреть файл

@ -1 +0,0 @@
{"storm_id": "a", "relative_time": "0", "ocean": "2"}

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 333 B

Просмотреть файл

@ -1 +0,0 @@
{"storm_id": "b", "relative_time": "0", "ocean": "2"}

Просмотреть файл

@ -1 +0,0 @@
{"storm_id": "c", "relative_time": "0", "ocean": "2"}

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 333 B

Просмотреть файл

@ -1 +0,0 @@
{"storm_id": "d", "relative_time": "0", "ocean": "2"}

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 333 B

Просмотреть файл

@ -1 +0,0 @@
{"storm_id": "e", "relative_time": "0", "ocean": "2"}

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 333 B

Двоичные данные
tests/data/cyclone/test/aaa_000.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 680 B

Двоичные данные
tests/data/cyclone/test/bbb_111.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 357 B

Двоичные данные
tests/data/cyclone/test/ccc_222.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 671 B

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 333 B

После

Ширина:  |  Высота:  |  Размер: 347 B

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 631 B

После

Ширина:  |  Высота:  |  Размер: 665 B

Просмотреть файл

@ -0,0 +1,6 @@
Image ID,Storm ID,Relative Time,Ocean
aaa_000,aaa,0,0
bbb_111,bbb,1,1
ccc_222,ccc,2,2
ddd_333,ddd,3,3
eee_444,eee,4,4
1 Image ID Storm ID Relative Time Ocean
2 aaa_000 aaa 0 0
3 bbb_111 bbb 1 1
4 ccc_222 ccc 2 2
5 ddd_333 ddd 3 3
6 eee_444 eee 4 4

Просмотреть файл

@ -0,0 +1,6 @@
Image ID,Wind Speed
aaa_000,0
bbb_111,1
ccc_222,2
ddd_333,3
eee_444,4
1 Image ID Wind Speed
2 aaa_000 0
3 bbb_111 1
4 ccc_222 2
5 ddd_333 3
6 eee_444 4

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 333 B

После

Ширина:  |  Высота:  |  Размер: 352 B

Двоичные данные
tests/data/cyclone/train/ggg_666.jpg Normal file

Двоичный файл не отображается.

После

Ширина:  |  Высота:  |  Размер: 672 B

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 333 B

После

Ширина:  |  Высота:  |  Размер: 344 B

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 631 B

После

Ширина:  |  Высота:  |  Размер: 668 B

Двоичный файл не отображается.

До

Ширина:  |  Высота:  |  Размер: 333 B

После

Ширина:  |  Высота:  |  Размер: 342 B

Просмотреть файл

@ -0,0 +1,6 @@
Image ID,Storm ID,Relative Time,Ocean
fff_555,fff,5,5
ggg_666,ggg,6,6
hhh_777,hhh,7,7
iii_888,iii,8,8
jjj_999,jjj,9,9
1 Image ID Storm ID Relative Time Ocean
2 fff_555 fff 5 5
3 ggg_666 ggg 6 6
4 hhh_777 hhh 7 7
5 iii_888 iii 8 8
6 jjj_999 jjj 9 9

Просмотреть файл

@ -0,0 +1,6 @@
Image ID,Wind Speed
fff_555,5
ggg_666,6
hhh_777,7
iii_888,8
jjj_999,9
1 Image ID Wind Speed
2 fff_555 5
3 ggg_666 6
4 hhh_777 7
5 iii_888 8
6 jjj_999 9

Просмотреть файл

@ -1,9 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import glob
import os
import shutil
from pathlib import Path
import matplotlib.pyplot as plt
@ -15,52 +13,33 @@ from pytest import MonkeyPatch
from torch.utils.data import ConcatDataset
from torchgeo.datasets import DatasetNotFoundError, TropicalCyclone
class Collection:
def download(self, output_dir: str, **kwargs: str) -> None:
for tarball in glob.iglob(os.path.join('tests', 'data', 'cyclone', '*.tar.gz')):
shutil.copy(tarball, output_dir)
def fetch(collection_id: str, **kwargs: str) -> Collection:
return Collection()
from torchgeo.datasets.utils import Executable
class TestTropicalCyclone:
@pytest.fixture(params=['train', 'test'])
def dataset(
self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
self,
request: SubRequest,
azcopy: Executable,
monkeypatch: MonkeyPatch,
tmp_path: Path,
) -> TropicalCyclone:
radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
md5s = {
'train': {
'source': '2b818e0a0873728dabf52c7054a0ce4c',
'labels': 'c3c2b6d02c469c5519f4add4f9132712',
},
'test': {
'source': 'bc07c519ddf3ce88857435ddddf98a16',
'labels': '3ca4243eff39b87c73e05ec8db1824bf',
},
}
monkeypatch.setattr(TropicalCyclone, 'md5s', md5s)
monkeypatch.setattr(TropicalCyclone, 'size', 1)
url = os.path.join('tests', 'data', 'cyclone')
monkeypatch.setattr(TropicalCyclone, 'url', url)
monkeypatch.setattr(TropicalCyclone, 'size', 2)
root = str(tmp_path)
split = request.param
transforms = nn.Identity()
return TropicalCyclone(
root, split, transforms, download=True, api_key='', checksum=True
)
return TropicalCyclone(root, split, transforms, download=True)
@pytest.mark.parametrize('index', [0, 1])
def test_getitem(self, dataset: TropicalCyclone, index: int) -> None:
x = dataset[index]
assert isinstance(x, dict)
assert isinstance(x['image'], torch.Tensor)
assert isinstance(x['storm_id'], str)
assert isinstance(x['relative_time'], int)
assert isinstance(x['ocean'], int)
assert isinstance(x['relative_time'], torch.Tensor)
assert isinstance(x['ocean'], torch.Tensor)
assert isinstance(x['label'], torch.Tensor)
assert x['image'].shape == (3, dataset.size, dataset.size)
@ -73,7 +52,7 @@ class TestTropicalCyclone:
assert len(ds) == 10
def test_already_downloaded(self, dataset: TropicalCyclone) -> None:
TropicalCyclone(root=dataset.root, download=True, api_key='')
TropicalCyclone(root=dataset.root, download=True)
def test_invalid_split(self) -> None:
with pytest.raises(AssertionError):
@ -84,10 +63,9 @@ class TestTropicalCyclone:
TropicalCyclone(str(tmp_path))
def test_plot(self, dataset: TropicalCyclone) -> None:
dataset.plot(dataset[0], suptitle='Test')
plt.close()
sample = dataset[0]
dataset.plot(sample, suptitle='Test')
plt.close()
sample['prediction'] = sample['label']
dataset.plot(sample)
plt.close()

Просмотреть файл

@ -597,7 +597,7 @@ def test_lazy_import_missing(name: str) -> None:
def test_azcopy(tmp_path: Path, azcopy: Executable) -> None:
source = os.path.join('tests', 'data', 'cyclone')
azcopy('sync', source, tmp_path, '--recursive=true')
assert os.path.exists(tmp_path / 'nasa_tropical_storm_competition_test_labels')
assert os.path.exists(tmp_path / 'test')
def test_which() -> None:

Просмотреть файл

@ -43,18 +43,11 @@ class TropicalCycloneDataModule(NonGeoDataModule):
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
if stage in ['fit', 'validate']:
self.dataset = TropicalCyclone(split='train', **self.kwargs)
storm_ids = []
for item in self.dataset.collection:
storm_id = item['href'].split('/')[0].split('_')[-2]
storm_ids.append(storm_id)
dataset = TropicalCyclone(split='train', **self.kwargs)
train_indices, val_indices = group_shuffle_split(
storm_ids, test_size=0.2, random_state=0
dataset.features['Storm ID'], test_size=0.2, random_state=0
)
self.train_dataset = Subset(self.dataset, train_indices)
self.val_dataset = Subset(self.dataset, val_indices)
self.train_dataset = Subset(dataset, train_indices)
self.val_dataset = Subset(dataset, val_indices)
if stage in ['test']:
self.test_dataset = TropicalCyclone(split='test', **self.kwargs)

Просмотреть файл

@ -3,7 +3,6 @@
"""Tropical Cyclone Wind Estimation Competition dataset."""
import json
import os
from collections.abc import Callable
from functools import lru_cache
@ -11,6 +10,7 @@ from typing import Any
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.figure import Figure
from PIL import Image
@ -18,7 +18,7 @@ from torch import Tensor
from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
from .utils import which
class TropicalCyclone(NonGeoDataset):
@ -26,10 +26,9 @@ class TropicalCyclone(NonGeoDataset):
A collection of tropical storms in the Atlantic and East Pacific Oceans from 2000 to
2019 with corresponding maximum sustained surface wind speed. This dataset is split
into training and test categories for the purpose of a competition.
See https://www.drivendata.org/competitions/72/predict-wind-speeds/ for more
information about the competition.
into training and test categories for the purpose of a competition. Read more about
the competition here:
https://www.drivendata.org/competitions/72/predict-wind-speeds/.
If you use this dataset in your research, please cite the following paper:
@ -39,31 +38,17 @@ class TropicalCyclone(NonGeoDataset):
This dataset requires the following additional library to be installed:
* `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
imagery and labels from the Radiant Earth MLHub
* `azcopy <https://github.com/Azure/azure-storage-azcopy>`_: to download the
dataset from Source Cooperative.
.. versionchanged:: 0.4
Class name changed from TropicalCycloneWindEstimation to TropicalCyclone
to be consistent with TropicalCycloneDataModule.
"""
collection_id = 'nasa_tropical_storm_competition'
collection_ids = [
'nasa_tropical_storm_competition_train_source',
'nasa_tropical_storm_competition_test_source',
'nasa_tropical_storm_competition_train_labels',
'nasa_tropical_storm_competition_test_labels',
]
md5s = {
'train': {
'source': '97e913667a398704ea8d28196d91dad6',
'labels': '97d02608b74c82ffe7496a9404a30413',
},
'test': {
'source': '8d88099e4b310feb7781d776a6e1dcef',
'labels': 'd910c430f90153c1f78a99cbc08e7bd0',
},
}
url = (
'https://radiantearth.blob.core.windows.net/mlhub/nasa-tropical-storm-challenge'
)
size = 366
def __init__(
@ -72,10 +57,8 @@ class TropicalCyclone(NonGeoDataset):
split: str = 'train',
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
api_key: str | None = None,
checksum: bool = False,
) -> None:
"""Initialize a new Tropical Cyclone Wind Estimation Competition Dataset.
"""Initialize a new TropicalCyclone instance.
Args:
root: root directory where dataset can be found
@ -83,30 +66,26 @@ class TropicalCyclone(NonGeoDataset):
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
AssertionError: if ``split`` argument is invalid
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.md5s
assert split in {'train', 'test'}
self.root = root
self.split = split
self.transforms = transforms
self.checksum = checksum
self.download = download
if download:
self._download(api_key)
self.filename = f'{split}_set'
if split == 'train':
self.filename = f'{split}ing_set'
if not self._check_integrity():
raise DatasetNotFoundError(self)
self._verify()
output_dir = '_'.join([self.collection_id, split, 'source'])
filename = os.path.join(root, output_dir, 'collection.json')
with open(filename) as f:
self.collection = json.load(f)['links']
self.features = pd.read_csv(os.path.join(root, f'{self.filename}_features.csv'))
self.labels = pd.read_csv(os.path.join(root, f'{self.filename}_labels.csv'))
def __getitem__(self, index: int) -> dict[str, Any]:
"""Return an index within the dataset.
@ -117,15 +96,14 @@ class TropicalCyclone(NonGeoDataset):
Returns:
data, labels, field ids, and metadata at that index
"""
source_id = os.path.split(self.collection[index]['href'])[0]
directory = os.path.join(
self.root,
'_'.join([self.collection_id, self.split, '{0}']),
source_id.replace('source', '{0}'),
)
sample = {
'relative_time': torch.tensor(self.features.iat[index, 2]),
'ocean': torch.tensor(self.features.iat[index, 3]),
'label': torch.tensor(self.labels.iat[index, 1]),
}
sample: dict[str, Any] = {'image': self._load_image(directory)}
sample.update(self._load_features(directory))
image_id = self.labels.iat[index, 0]
sample['image'] = self._load_image(image_id)
if self.transforms is not None:
sample = self.transforms(sample)
@ -138,19 +116,19 @@ class TropicalCyclone(NonGeoDataset):
Returns:
length of the dataset
"""
return len(self.collection)
return len(self.labels)
@lru_cache
def _load_image(self, directory: str) -> Tensor:
def _load_image(self, image_id: str) -> Tensor:
"""Load a single image.
Args:
directory: directory containing image
image_id: Filename of the image.
Returns:
the image
"""
filename = os.path.join(directory.format('source'), 'image.jpg')
filename = os.path.join(self.root, self.split, f'{image_id}.jpg')
with Image.open(filename) as img:
if img.height != self.size or img.width != self.size:
# Moved in PIL 9.1.0
@ -164,61 +142,30 @@ class TropicalCyclone(NonGeoDataset):
tensor = tensor.permute((2, 0, 1)).float()
return tensor
def _load_features(self, directory: str) -> dict[str, Any]:
"""Load features for a single image.
Args:
directory: directory containing image
Returns:
the features
"""
filename = os.path.join(directory.format('source'), 'features.json')
with open(filename) as f:
features: dict[str, Any] = json.load(f)
filename = os.path.join(directory.format('labels'), 'labels.json')
with open(filename) as f:
features.update(json.load(f))
features['relative_time'] = int(features['relative_time'])
features['ocean'] = int(features['ocean'])
features['label'] = torch.tensor(int(features['wind_speed'])).float()
return features
def _check_integrity(self) -> bool:
"""Check integrity of dataset.
Returns:
True if dataset files are found and/or MD5s match, else False
"""
for split, resources in self.md5s.items():
for resource_type, md5 in resources.items():
filename = '_'.join([self.collection_id, split, resource_type])
filename = os.path.join(self.root, filename + '.tar.gz')
if not check_integrity(filename, md5 if self.checksum else None):
return False
return True
def _download(self, api_key: str | None = None) -> None:
"""Download the dataset and extract it.
Args:
api_key: a RadiantEarth MLHub API key to use for downloading the dataset
"""
if self._check_integrity():
print('Files already downloaded and verified')
def _verify(self) -> None:
"""Verify the integrity of the dataset."""
# Check if the files already exist
files = [f'{self.filename}_features.csv', f'{self.filename}_labels.csv']
exists = [os.path.exists(os.path.join(self.root, file)) for file in files]
if all(exists):
return
for collection_id in self.collection_ids:
download_radiant_mlhub_collection(collection_id, self.root, api_key)
# Check if the user requested to download the dataset
if not self.download:
raise DatasetNotFoundError(self)
for split, resources in self.md5s.items():
for resource_type in resources:
filename = '_'.join([self.collection_id, split, resource_type])
filename = os.path.join(self.root, filename) + '.tar.gz'
extract_archive(filename, self.root)
# Download the dataset
self._download()
def _download(self) -> None:
"""Download the dataset."""
directory = os.path.join(self.root, self.split)
os.makedirs(directory, exist_ok=True)
azcopy = which('azcopy')
azcopy('sync', f'{self.url}/{self.split}', directory, '--recursive=true')
files = [f'{self.filename}_features.csv', f'{self.filename}_labels.csv']
for file in files:
azcopy('copy', f'{self.url}/{file}', self.root)
def plot(
self,