Tropical Cyclone: radiant mlhub -> source cooperative (#2068)
* Tropical Cyclone: radiant mlhub -> source cooperative
* Update azcopy unit tests
@@ -0,0 +1,32 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+
+import numpy as np
+import pandas as pd
+from PIL import Image
+
+DTYPE = np.uint8
+SIZE = 2
+
+np.random.seed(0)
+
+for split in ['train', 'test']:
+    os.makedirs(split, exist_ok=True)
+
+    filename = split
+    if split == 'train':
+        filename = 'training'
+
+    features = pd.read_csv(f'{filename}_set_features.csv')
+    for image_id, _, _, ocean in features.values:
+        size = (SIZE, SIZE)
+        if ocean % 2 == 0:
+            size = (SIZE * 2, SIZE * 2, 3)
+
+        arr = np.random.randint(np.iinfo(DTYPE).max, size=size, dtype=DTYPE)
+        img = Image.fromarray(arr)
+        img.save(os.path.join(split, f'{image_id}.jpg'))

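As a quick sanity check on these fixtures, every Image ID in a features CSV should map to a JPEG under its split directory, mirroring how the dataset locates imagery. A minimal sketch, assuming the script above has already been run next to the CSV fixtures added below:

    import os

    import pandas as pd

    # data.py names each file '<image_id>.jpg', so every Image ID in a
    # features CSV should have a matching JPEG in its split directory.
    pairs = [('training_set_features.csv', 'train'), ('test_set_features.csv', 'test')]
    for csv_file, split in pairs:
        for image_id in pd.read_csv(csv_file)['Image ID']:
            assert os.path.exists(os.path.join(split, f'{image_id}.jpg'))
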
@@ -1,24 +0,0 @@
-{
-    "links": [
-        {
-            "href": "nasa_tropical_storm_competition_test_labels_a_000/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_labels_b_001/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_labels_c_002/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_labels_d_003/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_labels_e_004/stac.json",
-            "rel": "item"
-        }
-    ]
-}

@@ -1 +0,0 @@
-{"wind_speed": "34"}

@@ -1 +0,0 @@
-{"wind_speed": "34"}

@@ -1 +0,0 @@
-{"wind_speed": "34"}

@@ -1 +0,0 @@
-{"wind_speed": "34"}

@@ -1 +0,0 @@
-{"wind_speed": "34"}

@@ -1,24 +0,0 @@
-{
-    "links": [
-        {
-            "href": "nasa_tropical_storm_competition_test_source_a_000/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_source_b_001/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_source_c_002/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_source_d_003/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_test_source_e_004/stac.json",
-            "rel": "item"
-        }
-    ]
-}

@@ -1 +0,0 @@
-{"storm_id": "a", "relative_time": "0", "ocean": "2"}

@@ -1 +0,0 @@
-{"storm_id": "b", "relative_time": "0", "ocean": "2"}

@@ -1 +0,0 @@
-{"storm_id": "c", "relative_time": "0", "ocean": "2"}

@@ -1 +0,0 @@
-{"storm_id": "d", "relative_time": "0", "ocean": "2"}

@@ -1 +0,0 @@
-{"storm_id": "e", "relative_time": "0", "ocean": "2"}

@@ -1,24 +0,0 @@
-{
-    "links": [
-        {
-            "href": "nasa_tropical_storm_competition_train_labels_a_000/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_labels_b_001/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_labels_c_002/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_labels_d_003/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_labels_e_004/stac.json",
-            "rel": "item"
-        }
-    ]
-}

@@ -1 +0,0 @@
-{"wind_speed": "34"}

@@ -1 +0,0 @@
-{"wind_speed": "34"}

@@ -1 +0,0 @@
-{"wind_speed": "34"}

@@ -1 +0,0 @@
-{"wind_speed": "34"}

@@ -1 +0,0 @@
-{"wind_speed": "34"}

@@ -1,24 +0,0 @@
-{
-    "links": [
-        {
-            "href": "nasa_tropical_storm_competition_train_source_a_000/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_source_b_001/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_source_c_002/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_source_d_003/stac.json",
-            "rel": "item"
-        },
-        {
-            "href": "nasa_tropical_storm_competition_train_source_e_004/stac.json",
-            "rel": "item"
-        }
-    ]
-}

@@ -1 +0,0 @@
-{"storm_id": "a", "relative_time": "0", "ocean": "2"}

@@ -1 +0,0 @@
-{"storm_id": "b", "relative_time": "0", "ocean": "2"}

@@ -1 +0,0 @@
-{"storm_id": "c", "relative_time": "0", "ocean": "2"}

@@ -1 +0,0 @@
-{"storm_id": "d", "relative_time": "0", "ocean": "2"}

@@ -1 +0,0 @@
-{"storm_id": "e", "relative_time": "0", "ocean": "2"}

[Binary JPEG test fixtures added/updated (before 333-631 B, after 347-680 B)]

@@ -0,0 +1,6 @@
+Image ID,Storm ID,Relative Time,Ocean
+aaa_000,aaa,0,0
+bbb_111,bbb,1,1
+ccc_222,ccc,2,2
+ddd_333,ddd,3,3
+eee_444,eee,4,4

@@ -0,0 +1,6 @@
+Image ID,Wind Speed
+aaa_000,0
+bbb_111,1
+ccc_222,2
+ddd_333,3
+eee_444,4

[Binary JPEG test fixtures added/updated (before 333-631 B, after 342-672 B)]

@@ -0,0 +1,6 @@
+Image ID,Storm ID,Relative Time,Ocean
+fff_555,fff,5,5
+ggg_666,ggg,6,6
+hhh_777,hhh,7,7
+iii_888,iii,8,8
+jjj_999,jjj,9,9

@@ -0,0 +1,6 @@
+Image ID,Wind Speed
+fff_555,5
+ggg_666,6
+hhh_777,7
+iii_888,8
+jjj_999,9

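Note that data.py consumes these CSVs positionally via DataFrame.values rather than by column name; a small illustration of that unpacking, using the training CSV above:

    import pandas as pd

    features = pd.read_csv('training_set_features.csv')
    # DataFrame.values yields rows as arrays, so each row unpacks in column
    # order: Image ID, Storm ID, Relative Time, Ocean.
    for image_id, storm_id, relative_time, ocean in features.values:
        print(image_id, storm_id, relative_time, ocean)
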
@@ -1,9 +1,7 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 
-import glob
 import os
-import shutil
 from pathlib import Path
 
 import matplotlib.pyplot as plt

@@ -15,52 +13,33 @@ from pytest import MonkeyPatch
 from torch.utils.data import ConcatDataset
 
 from torchgeo.datasets import DatasetNotFoundError, TropicalCyclone
-
-
-class Collection:
-    def download(self, output_dir: str, **kwargs: str) -> None:
-        for tarball in glob.iglob(os.path.join('tests', 'data', 'cyclone', '*.tar.gz')):
-            shutil.copy(tarball, output_dir)
-
-
-def fetch(collection_id: str, **kwargs: str) -> Collection:
-    return Collection()
+from torchgeo.datasets.utils import Executable
 
 
 class TestTropicalCyclone:
     @pytest.fixture(params=['train', 'test'])
     def dataset(
-        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
+        self,
+        request: SubRequest,
+        azcopy: Executable,
+        monkeypatch: MonkeyPatch,
+        tmp_path: Path,
     ) -> TropicalCyclone:
-        radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3')
-        monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch)
-        md5s = {
-            'train': {
-                'source': '2b818e0a0873728dabf52c7054a0ce4c',
-                'labels': 'c3c2b6d02c469c5519f4add4f9132712',
-            },
-            'test': {
-                'source': 'bc07c519ddf3ce88857435ddddf98a16',
-                'labels': '3ca4243eff39b87c73e05ec8db1824bf',
-            },
-        }
-        monkeypatch.setattr(TropicalCyclone, 'md5s', md5s)
-        monkeypatch.setattr(TropicalCyclone, 'size', 1)
+        url = os.path.join('tests', 'data', 'cyclone')
+        monkeypatch.setattr(TropicalCyclone, 'url', url)
+        monkeypatch.setattr(TropicalCyclone, 'size', 2)
         root = str(tmp_path)
         split = request.param
         transforms = nn.Identity()
-        return TropicalCyclone(
-            root, split, transforms, download=True, api_key='', checksum=True
-        )
+        return TropicalCyclone(root, split, transforms, download=True)
 
     @pytest.mark.parametrize('index', [0, 1])
     def test_getitem(self, dataset: TropicalCyclone, index: int) -> None:
         x = dataset[index]
         assert isinstance(x, dict)
         assert isinstance(x['image'], torch.Tensor)
-        assert isinstance(x['storm_id'], str)
-        assert isinstance(x['relative_time'], int)
-        assert isinstance(x['ocean'], int)
+        assert isinstance(x['relative_time'], torch.Tensor)
+        assert isinstance(x['ocean'], torch.Tensor)
         assert isinstance(x['label'], torch.Tensor)
         assert x['image'].shape == (3, dataset.size, dataset.size)
 
@@ -73,7 +52,7 @@ class TestTropicalCyclone:
         assert len(ds) == 10
 
     def test_already_downloaded(self, dataset: TropicalCyclone) -> None:
-        TropicalCyclone(root=dataset.root, download=True, api_key='')
+        TropicalCyclone(root=dataset.root, download=True)
 
     def test_invalid_split(self) -> None:
         with pytest.raises(AssertionError):

@@ -84,10 +63,9 @@ class TestTropicalCyclone:
         TropicalCyclone(str(tmp_path))
 
     def test_plot(self, dataset: TropicalCyclone) -> None:
-        dataset.plot(dataset[0], suptitle='Test')
-        plt.close()
-
         sample = dataset[0]
+        dataset.plot(sample, suptitle='Test')
+        plt.close()
         sample['prediction'] = sample['label']
         dataset.plot(sample)
         plt.close()

@@ -597,7 +597,7 @@ def test_lazy_import_missing(name: str) -> None:
 def test_azcopy(tmp_path: Path, azcopy: Executable) -> None:
     source = os.path.join('tests', 'data', 'cyclone')
     azcopy('sync', source, tmp_path, '--recursive=true')
-    assert os.path.exists(tmp_path / 'nasa_tropical_storm_competition_test_labels')
+    assert os.path.exists(tmp_path / 'test')
 
 
 def test_which() -> None:

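The azcopy fixture wraps the real CLI, so the sync above behaves like calling torchgeo's which() helper by hand. A hedged sketch (destination path is hypothetical; requires azcopy on PATH):

    from torchgeo.datasets.utils import which

    # which() returns a callable wrapper around the executable; each
    # positional argument becomes a CLI argument.
    azcopy = which('azcopy')
    azcopy('sync', 'tests/data/cyclone', '/tmp/cyclone', '--recursive=true')
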
@@ -43,18 +43,11 @@ class TropicalCycloneDataModule(NonGeoDataModule):
             stage: Either 'fit', 'validate', 'test', or 'predict'.
         """
         if stage in ['fit', 'validate']:
-            self.dataset = TropicalCyclone(split='train', **self.kwargs)
-
-            storm_ids = []
-            for item in self.dataset.collection:
-                storm_id = item['href'].split('/')[0].split('_')[-2]
-                storm_ids.append(storm_id)
-
+            dataset = TropicalCyclone(split='train', **self.kwargs)
             train_indices, val_indices = group_shuffle_split(
-                storm_ids, test_size=0.2, random_state=0
+                dataset.features['Storm ID'], test_size=0.2, random_state=0
             )
-
-            self.train_dataset = Subset(self.dataset, train_indices)
-            self.val_dataset = Subset(self.dataset, val_indices)
+            self.train_dataset = Subset(dataset, train_indices)
+            self.val_dataset = Subset(dataset, val_indices)
         if stage in ['test']:
             self.test_dataset = TropicalCyclone(split='test', **self.kwargs)

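The switch to dataset.features['Storm ID'] preserves the grouped-split semantics: group_shuffle_split keeps every sample sharing a group value on one side of the split, so frames of a single storm never straddle train and val. A minimal sketch with made-up storm IDs:

    from torchgeo.datamodules.utils import group_shuffle_split

    # Two frames per storm; each storm's indices land entirely in one split.
    storm_ids = ['aaa', 'aaa', 'bbb', 'bbb', 'ccc', 'ccc', 'ddd', 'ddd', 'eee', 'eee']
    train_indices, val_indices = group_shuffle_split(storm_ids, test_size=0.2, random_state=0)
    assert not {storm_ids[i] for i in train_indices} & {storm_ids[i] for i in val_indices}
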
@@ -3,7 +3,6 @@
 
 """Tropical Cyclone Wind Estimation Competition dataset."""
 
-import json
 import os
 from collections.abc import Callable
 from functools import lru_cache

@@ -11,6 +10,7 @@ from typing import Any
 
 import matplotlib.pyplot as plt
 import numpy as np
+import pandas as pd
 import torch
 from matplotlib.figure import Figure
 from PIL import Image

@@ -18,7 +18,7 @@ from torch import Tensor
 
 from .errors import DatasetNotFoundError
 from .geo import NonGeoDataset
-from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive
+from .utils import which
 
 
 class TropicalCyclone(NonGeoDataset):

@@ -26,10 +26,9 @@ class TropicalCyclone(NonGeoDataset):
 
     A collection of tropical storms in the Atlantic and East Pacific Oceans from 2000 to
     2019 with corresponding maximum sustained surface wind speed. This dataset is split
-    into training and test categories for the purpose of a competition.
-
-    See https://www.drivendata.org/competitions/72/predict-wind-speeds/ for more
-    information about the competition.
+    into training and test categories for the purpose of a competition. Read more about
+    the competition here:
+    https://www.drivendata.org/competitions/72/predict-wind-speeds/.
 
     If you use this dataset in your research, please cite the following paper:
 
@@ -39,31 +38,17 @@ class TropicalCyclone(NonGeoDataset):
 
     This dataset requires the following additional library to be installed:
 
-    * `radiant-mlhub <https://pypi.org/project/radiant-mlhub/>`_ to download the
-      imagery and labels from the Radiant Earth MLHub
+    * `azcopy <https://github.com/Azure/azure-storage-azcopy>`_: to download the
+      dataset from Source Cooperative.
 
     .. versionchanged:: 0.4
        Class name changed from TropicalCycloneWindEstimation to TropicalCyclone
        to be consistent with TropicalCycloneDataModule.
     """
 
-    collection_id = 'nasa_tropical_storm_competition'
-    collection_ids = [
-        'nasa_tropical_storm_competition_train_source',
-        'nasa_tropical_storm_competition_test_source',
-        'nasa_tropical_storm_competition_train_labels',
-        'nasa_tropical_storm_competition_test_labels',
-    ]
-    md5s = {
-        'train': {
-            'source': '97e913667a398704ea8d28196d91dad6',
-            'labels': '97d02608b74c82ffe7496a9404a30413',
-        },
-        'test': {
-            'source': '8d88099e4b310feb7781d776a6e1dcef',
-            'labels': 'd910c430f90153c1f78a99cbc08e7bd0',
-        },
-    }
+    url = (
+        'https://radiantearth.blob.core.windows.net/mlhub/nasa-tropical-storm-challenge'
+    )
     size = 366
 
     def __init__(

@@ -72,10 +57,8 @@ class TropicalCyclone(NonGeoDataset):
         split: str = 'train',
         transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
         download: bool = False,
-        api_key: str | None = None,
-        checksum: bool = False,
     ) -> None:
-        """Initialize a new Tropical Cyclone Wind Estimation Competition Dataset.
+        """Initialize a new TropicalCyclone instance.
 
         Args:
             root: root directory where dataset can be found

@@ -83,30 +66,26 @@ class TropicalCyclone(NonGeoDataset):
             transforms: a function/transform that takes input sample and its target as
                 entry and returns a transformed version
             download: if True, download dataset and store it in the root directory
-            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
-            checksum: if True, check the MD5 of the downloaded files (may be slow)
 
         Raises:
             AssertionError: if ``split`` argument is invalid
             DatasetNotFoundError: If dataset is not found and *download* is False.
         """
-        assert split in self.md5s
+        assert split in {'train', 'test'}
 
         self.root = root
         self.split = split
         self.transforms = transforms
-        self.checksum = checksum
+        self.download = download
 
-        if download:
-            self._download(api_key)
+        self.filename = f'{split}_set'
+        if split == 'train':
+            self.filename = f'{split}ing_set'
 
-        if not self._check_integrity():
-            raise DatasetNotFoundError(self)
+        self._verify()
 
-        output_dir = '_'.join([self.collection_id, split, 'source'])
-        filename = os.path.join(root, output_dir, 'collection.json')
-        with open(filename) as f:
-            self.collection = json.load(f)['links']
+        self.features = pd.read_csv(os.path.join(root, f'{self.filename}_features.csv'))
+        self.labels = pd.read_csv(os.path.join(root, f'{self.filename}_labels.csv'))
 
     def __getitem__(self, index: int) -> dict[str, Any]:
         """Return an index within the dataset.

@@ -117,15 +96,14 @@ class TropicalCyclone(NonGeoDataset):
         Returns:
             data, labels, field ids, and metadata at that index
         """
-        source_id = os.path.split(self.collection[index]['href'])[0]
-        directory = os.path.join(
-            self.root,
-            '_'.join([self.collection_id, self.split, '{0}']),
-            source_id.replace('source', '{0}'),
-        )
+        sample = {
+            'relative_time': torch.tensor(self.features.iat[index, 2]),
+            'ocean': torch.tensor(self.features.iat[index, 3]),
+            'label': torch.tensor(self.labels.iat[index, 1]),
+        }
 
-        sample: dict[str, Any] = {'image': self._load_image(directory)}
-        sample.update(self._load_features(directory))
+        image_id = self.labels.iat[index, 0]
+        sample['image'] = self._load_image(image_id)
 
         if self.transforms is not None:
             sample = self.transforms(sample)

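With this change a sample is assembled from the two CSVs plus a single JPEG, with no STAC JSON involved. A hedged usage sketch (root path is hypothetical; download=True requires azcopy on PATH):

    from torchgeo.datasets import TropicalCyclone

    ds = TropicalCyclone(root='data/cyclone', split='train', download=True)
    sample = ds[0]
    # 'image' is a CxHxW float tensor; 'relative_time', 'ocean', and 'label'
    # are scalar tensors pulled from the features and labels CSVs.
    print(sample['image'].shape, sample['relative_time'], sample['ocean'], sample['label'])
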
@@ -138,19 +116,19 @@ class TropicalCyclone(NonGeoDataset):
         Returns:
             length of the dataset
         """
-        return len(self.collection)
+        return len(self.labels)
 
     @lru_cache
-    def _load_image(self, directory: str) -> Tensor:
+    def _load_image(self, image_id: str) -> Tensor:
         """Load a single image.
 
         Args:
-            directory: directory containing image
+            image_id: Filename of the image.
 
         Returns:
             the image
         """
-        filename = os.path.join(directory.format('source'), 'image.jpg')
+        filename = os.path.join(self.root, self.split, f'{image_id}.jpg')
         with Image.open(filename) as img:
             if img.height != self.size or img.width != self.size:
                 # Moved in PIL 9.1.0

@@ -164,61 +142,30 @@ class TropicalCyclone(NonGeoDataset):
         tensor = tensor.permute((2, 0, 1)).float()
         return tensor
 
-    def _load_features(self, directory: str) -> dict[str, Any]:
-        """Load features for a single image.
-
-        Args:
-            directory: directory containing image
-
-        Returns:
-            the features
-        """
-        filename = os.path.join(directory.format('source'), 'features.json')
-        with open(filename) as f:
-            features: dict[str, Any] = json.load(f)
-
-        filename = os.path.join(directory.format('labels'), 'labels.json')
-        with open(filename) as f:
-            features.update(json.load(f))
-
-        features['relative_time'] = int(features['relative_time'])
-        features['ocean'] = int(features['ocean'])
-        features['label'] = torch.tensor(int(features['wind_speed'])).float()
-
-        return features
-
-    def _check_integrity(self) -> bool:
-        """Check integrity of dataset.
-
-        Returns:
-            True if dataset files are found and/or MD5s match, else False
-        """
-        for split, resources in self.md5s.items():
-            for resource_type, md5 in resources.items():
-                filename = '_'.join([self.collection_id, split, resource_type])
-                filename = os.path.join(self.root, filename + '.tar.gz')
-                if not check_integrity(filename, md5 if self.checksum else None):
-                    return False
-        return True
-
-    def _download(self, api_key: str | None = None) -> None:
-        """Download the dataset and extract it.
-
-        Args:
-            api_key: a RadiantEarth MLHub API key to use for downloading the dataset
-        """
-        if self._check_integrity():
-            print('Files already downloaded and verified')
-            return
+    def _verify(self) -> None:
+        """Verify the integrity of the dataset."""
+        # Check if the files already exist
+        files = [f'{self.filename}_features.csv', f'{self.filename}_labels.csv']
+        exists = [os.path.exists(os.path.join(self.root, file)) for file in files]
+        if all(exists):
+            return
 
-        for collection_id in self.collection_ids:
-            download_radiant_mlhub_collection(collection_id, self.root, api_key)
+        # Check if the user requested to download the dataset
+        if not self.download:
+            raise DatasetNotFoundError(self)
 
-        for split, resources in self.md5s.items():
-            for resource_type in resources:
-                filename = '_'.join([self.collection_id, split, resource_type])
-                filename = os.path.join(self.root, filename) + '.tar.gz'
-                extract_archive(filename, self.root)
+        # Download the dataset
+        self._download()
+
+    def _download(self) -> None:
+        """Download the dataset."""
+        directory = os.path.join(self.root, self.split)
+        os.makedirs(directory, exist_ok=True)
+        azcopy = which('azcopy')
+        azcopy('sync', f'{self.url}/{self.split}', directory, '--recursive=true')
+        files = [f'{self.filename}_features.csv', f'{self.filename}_labels.csv']
+        for file in files:
+            azcopy('copy', f'{self.url}/{file}', self.root)
 
     def plot(
         self,

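For orientation, the download path above should leave the split CSVs at the root and the per-image JPEGs in a per-split directory. A small sketch of checking that layout (root path is hypothetical):

    import os

    root = 'data/cyclone'
    # _download() copies the CSVs to root and syncs imagery into root/<split>,
    # where _load_image() then reads '<image_id>.jpg'.
    for entry in ('training_set_features.csv', 'training_set_labels.csv', 'train'):
        print(entry, os.path.exists(os.path.join(root, entry)))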