diff --git a/tests/data/cyclone/data.py b/tests/data/cyclone/data.py new file mode 100755 index 000000000..2ea0f7a42 --- /dev/null +++ b/tests/data/cyclone/data.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os + +import numpy as np +import pandas as pd +from PIL import Image + +DTYPE = np.uint8 +SIZE = 2 + +np.random.seed(0) + +for split in ['train', 'test']: + os.makedirs(split, exist_ok=True) + + filename = split + if split == 'train': + filename = 'training' + + features = pd.read_csv(f'{filename}_set_features.csv') + for image_id, _, _, ocean in features.values: + size = (SIZE, SIZE) + if ocean % 2 == 0: + size = (SIZE * 2, SIZE * 2, 3) + + arr = np.random.randint(np.iinfo(DTYPE).max, size=size, dtype=DTYPE) + img = Image.fromarray(arr) + img.save(os.path.join(split, f'{image_id}.jpg')) diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels.tar.gz b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels.tar.gz deleted file mode 100644 index cbfa3779d..000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels.tar.gz and /dev/null differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/collection.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/collection.json deleted file mode 100644 index a5692a66e..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/collection.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "links": [ - { - "href": "nasa_tropical_storm_competition_test_labels_a_000/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_labels_b_001/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_labels_c_002/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_labels_d_003/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_labels_e_004/stac.json", - "rel": "item" - } - ] -} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_a_000/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_a_000/labels.json deleted file mode 100644 index e59bae96d..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_a_000/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_b_001/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_b_001/labels.json deleted file mode 100644 index e59bae96d..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_b_001/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_c_002/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_c_002/labels.json deleted file mode 100644 index e59bae96d..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_c_002/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_d_003/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_d_003/labels.json deleted file mode 100644 index e59bae96d..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_d_003/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_e_004/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_e_004/labels.json deleted file mode 100644 index e59bae96d..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_labels/nasa_tropical_storm_competition_test_labels_e_004/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source.tar.gz b/tests/data/cyclone/nasa_tropical_storm_competition_test_source.tar.gz deleted file mode 100644 index 7a8162faf..000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source.tar.gz and /dev/null differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/collection.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/collection.json deleted file mode 100644 index 97c44e990..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/collection.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "links": [ - { - "href": "nasa_tropical_storm_competition_test_source_a_000/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_source_b_001/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_source_c_002/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_source_d_003/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_test_source_e_004/stac.json", - "rel": "item" - } - ] -} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/features.json deleted file mode 100644 index 83438ddff..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "a", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/features.json deleted file mode 100644 index 13f4a63af..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "b", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/features.json deleted file mode 100644 index d8671e264..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "c", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/features.json deleted file mode 100644 index a6eebd660..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "d", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/features.json deleted file mode 100644 index 90267dc6f..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "e", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels.tar.gz b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels.tar.gz deleted file mode 100644 index 83f913867..000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels.tar.gz and /dev/null differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/collection.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/collection.json deleted file mode 100644 index 834d29399..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/collection.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "links": [ - { - "href": "nasa_tropical_storm_competition_train_labels_a_000/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_labels_b_001/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_labels_c_002/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_labels_d_003/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_labels_e_004/stac.json", - "rel": "item" - } - ] -} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_a_000/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_a_000/labels.json deleted file mode 100644 index e59bae96d..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_a_000/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_b_001/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_b_001/labels.json deleted file mode 100644 index e59bae96d..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_b_001/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_c_002/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_c_002/labels.json deleted file mode 100644 index e59bae96d..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_c_002/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_d_003/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_d_003/labels.json deleted file mode 100644 index e59bae96d..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_d_003/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_e_004/labels.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_e_004/labels.json deleted file mode 100644 index e59bae96d..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_labels/nasa_tropical_storm_competition_train_labels_e_004/labels.json +++ /dev/null @@ -1 +0,0 @@ -{"wind_speed": "34"} \ No newline at end of file diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source.tar.gz b/tests/data/cyclone/nasa_tropical_storm_competition_train_source.tar.gz deleted file mode 100644 index b3f019e97..000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source.tar.gz and /dev/null differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/collection.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/collection.json deleted file mode 100644 index a03e0c77a..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/collection.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "links": [ - { - "href": "nasa_tropical_storm_competition_train_source_a_000/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_source_b_001/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_source_c_002/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_source_d_003/stac.json", - "rel": "item" - }, - { - "href": "nasa_tropical_storm_competition_train_source_e_004/stac.json", - "rel": "item" - } - ] -} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/features.json deleted file mode 100644 index 83438ddff..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "a", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/image.jpg b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/image.jpg deleted file mode 100644 index 79c38f2a9..000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_a_000/image.jpg and /dev/null differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/features.json deleted file mode 100644 index 13f4a63af..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "b", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/features.json deleted file mode 100644 index d8671e264..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "c", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/image.jpg b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/image.jpg deleted file mode 100644 index 79c38f2a9..000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_c_002/image.jpg and /dev/null differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/features.json deleted file mode 100644 index a6eebd660..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "d", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/image.jpg b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/image.jpg deleted file mode 100644 index 79c38f2a9..000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_d_003/image.jpg and /dev/null differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/features.json b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/features.json deleted file mode 100644 index 90267dc6f..000000000 --- a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/features.json +++ /dev/null @@ -1 +0,0 @@ -{"storm_id": "e", "relative_time": "0", "ocean": "2"} diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/image.jpg b/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/image.jpg deleted file mode 100644 index 79c38f2a9..000000000 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_e_004/image.jpg and /dev/null differ diff --git a/tests/data/cyclone/test/aaa_000.jpg b/tests/data/cyclone/test/aaa_000.jpg new file mode 100644 index 000000000..f4d039da9 Binary files /dev/null and b/tests/data/cyclone/test/aaa_000.jpg differ diff --git a/tests/data/cyclone/test/bbb_111.jpg b/tests/data/cyclone/test/bbb_111.jpg new file mode 100644 index 000000000..0d8e7a84a Binary files /dev/null and b/tests/data/cyclone/test/bbb_111.jpg differ diff --git a/tests/data/cyclone/test/ccc_222.jpg b/tests/data/cyclone/test/ccc_222.jpg new file mode 100644 index 000000000..ebd3ba67c Binary files /dev/null and b/tests/data/cyclone/test/ccc_222.jpg differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/image.jpg b/tests/data/cyclone/test/ddd_333.jpg similarity index 74% rename from tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/image.jpg rename to tests/data/cyclone/test/ddd_333.jpg index 79c38f2a9..575d5a5c6 100644 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_c_002/image.jpg and b/tests/data/cyclone/test/ddd_333.jpg differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/image.jpg b/tests/data/cyclone/test/eee_444.jpg similarity index 82% rename from tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/image.jpg rename to tests/data/cyclone/test/eee_444.jpg index 77c95fe87..0cd10728e 100644 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_train_source/nasa_tropical_storm_competition_train_source_b_001/image.jpg and b/tests/data/cyclone/test/eee_444.jpg differ diff --git a/tests/data/cyclone/test_set_features.csv b/tests/data/cyclone/test_set_features.csv new file mode 100644 index 000000000..dce291b0b --- /dev/null +++ b/tests/data/cyclone/test_set_features.csv @@ -0,0 +1,6 @@ +Image ID,Storm ID,Relative Time,Ocean +aaa_000,aaa,0,0 +bbb_111,bbb,1,1 +ccc_222,ccc,2,2 +ddd_333,ddd,3,3 +eee_444,eee,4,4 diff --git a/tests/data/cyclone/test_set_labels.csv b/tests/data/cyclone/test_set_labels.csv new file mode 100644 index 000000000..8aa2d7c7f --- /dev/null +++ b/tests/data/cyclone/test_set_labels.csv @@ -0,0 +1,6 @@ +Image ID,Wind Speed +aaa_000,0 +bbb_111,1 +ccc_222,2 +ddd_333,3 +eee_444,4 diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/image.jpg b/tests/data/cyclone/train/fff_555.jpg similarity index 73% rename from tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/image.jpg rename to tests/data/cyclone/train/fff_555.jpg index 79c38f2a9..15225859b 100644 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_d_003/image.jpg and b/tests/data/cyclone/train/fff_555.jpg differ diff --git a/tests/data/cyclone/train/ggg_666.jpg b/tests/data/cyclone/train/ggg_666.jpg new file mode 100644 index 000000000..3065b52a8 Binary files /dev/null and b/tests/data/cyclone/train/ggg_666.jpg differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/image.jpg b/tests/data/cyclone/train/hhh_777.jpg similarity index 75% rename from tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/image.jpg rename to tests/data/cyclone/train/hhh_777.jpg index 79c38f2a9..877ac76c4 100644 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_e_004/image.jpg and b/tests/data/cyclone/train/hhh_777.jpg differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/image.jpg b/tests/data/cyclone/train/iii_888.jpg similarity index 82% rename from tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/image.jpg rename to tests/data/cyclone/train/iii_888.jpg index 77c95fe87..731128b8a 100644 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_b_001/image.jpg and b/tests/data/cyclone/train/iii_888.jpg differ diff --git a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/image.jpg b/tests/data/cyclone/train/jjj_999.jpg similarity index 75% rename from tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/image.jpg rename to tests/data/cyclone/train/jjj_999.jpg index 79c38f2a9..8fda5ace9 100644 Binary files a/tests/data/cyclone/nasa_tropical_storm_competition_test_source/nasa_tropical_storm_competition_test_source_a_000/image.jpg and b/tests/data/cyclone/train/jjj_999.jpg differ diff --git a/tests/data/cyclone/training_set_features.csv b/tests/data/cyclone/training_set_features.csv new file mode 100644 index 000000000..56df786e8 --- /dev/null +++ b/tests/data/cyclone/training_set_features.csv @@ -0,0 +1,6 @@ +Image ID,Storm ID,Relative Time,Ocean +fff_555,fff,5,5 +ggg_666,ggg,6,6 +hhh_777,hhh,7,7 +iii_888,iii,8,8 +jjj_999,jjj,9,9 diff --git a/tests/data/cyclone/training_set_labels.csv b/tests/data/cyclone/training_set_labels.csv new file mode 100644 index 000000000..5a8bbabce --- /dev/null +++ b/tests/data/cyclone/training_set_labels.csv @@ -0,0 +1,6 @@ +Image ID,Wind Speed +fff_555,5 +ggg_666,6 +hhh_777,7 +iii_888,8 +jjj_999,9 diff --git a/tests/datasets/test_cyclone.py b/tests/datasets/test_cyclone.py index d165b064a..bb18bed06 100644 --- a/tests/datasets/test_cyclone.py +++ b/tests/datasets/test_cyclone.py @@ -1,9 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import glob import os -import shutil from pathlib import Path import matplotlib.pyplot as plt @@ -15,52 +13,33 @@ from pytest import MonkeyPatch from torch.utils.data import ConcatDataset from torchgeo.datasets import DatasetNotFoundError, TropicalCyclone - - -class Collection: - def download(self, output_dir: str, **kwargs: str) -> None: - for tarball in glob.iglob(os.path.join('tests', 'data', 'cyclone', '*.tar.gz')): - shutil.copy(tarball, output_dir) - - -def fetch(collection_id: str, **kwargs: str) -> Collection: - return Collection() +from torchgeo.datasets.utils import Executable class TestTropicalCyclone: @pytest.fixture(params=['train', 'test']) def dataset( - self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest + self, + request: SubRequest, + azcopy: Executable, + monkeypatch: MonkeyPatch, + tmp_path: Path, ) -> TropicalCyclone: - radiant_mlhub = pytest.importorskip('radiant_mlhub', minversion='0.3') - monkeypatch.setattr(radiant_mlhub.Collection, 'fetch', fetch) - md5s = { - 'train': { - 'source': '2b818e0a0873728dabf52c7054a0ce4c', - 'labels': 'c3c2b6d02c469c5519f4add4f9132712', - }, - 'test': { - 'source': 'bc07c519ddf3ce88857435ddddf98a16', - 'labels': '3ca4243eff39b87c73e05ec8db1824bf', - }, - } - monkeypatch.setattr(TropicalCyclone, 'md5s', md5s) - monkeypatch.setattr(TropicalCyclone, 'size', 1) + url = os.path.join('tests', 'data', 'cyclone') + monkeypatch.setattr(TropicalCyclone, 'url', url) + monkeypatch.setattr(TropicalCyclone, 'size', 2) root = str(tmp_path) split = request.param transforms = nn.Identity() - return TropicalCyclone( - root, split, transforms, download=True, api_key='', checksum=True - ) + return TropicalCyclone(root, split, transforms, download=True) @pytest.mark.parametrize('index', [0, 1]) def test_getitem(self, dataset: TropicalCyclone, index: int) -> None: x = dataset[index] assert isinstance(x, dict) assert isinstance(x['image'], torch.Tensor) - assert isinstance(x['storm_id'], str) - assert isinstance(x['relative_time'], int) - assert isinstance(x['ocean'], int) + assert isinstance(x['relative_time'], torch.Tensor) + assert isinstance(x['ocean'], torch.Tensor) assert isinstance(x['label'], torch.Tensor) assert x['image'].shape == (3, dataset.size, dataset.size) @@ -73,7 +52,7 @@ class TestTropicalCyclone: assert len(ds) == 10 def test_already_downloaded(self, dataset: TropicalCyclone) -> None: - TropicalCyclone(root=dataset.root, download=True, api_key='') + TropicalCyclone(root=dataset.root, download=True) def test_invalid_split(self) -> None: with pytest.raises(AssertionError): @@ -84,10 +63,9 @@ class TestTropicalCyclone: TropicalCyclone(str(tmp_path)) def test_plot(self, dataset: TropicalCyclone) -> None: - dataset.plot(dataset[0], suptitle='Test') - plt.close() - sample = dataset[0] + dataset.plot(sample, suptitle='Test') + plt.close() sample['prediction'] = sample['label'] dataset.plot(sample) plt.close() diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py index c53bfbed0..d6c9bc15c 100644 --- a/tests/datasets/test_utils.py +++ b/tests/datasets/test_utils.py @@ -597,7 +597,7 @@ def test_lazy_import_missing(name: str) -> None: def test_azcopy(tmp_path: Path, azcopy: Executable) -> None: source = os.path.join('tests', 'data', 'cyclone') azcopy('sync', source, tmp_path, '--recursive=true') - assert os.path.exists(tmp_path / 'nasa_tropical_storm_competition_test_labels') + assert os.path.exists(tmp_path / 'test') def test_which() -> None: diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index 39021fc2a..e9af30209 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -43,18 +43,11 @@ class TropicalCycloneDataModule(NonGeoDataModule): stage: Either 'fit', 'validate', 'test', or 'predict'. """ if stage in ['fit', 'validate']: - self.dataset = TropicalCyclone(split='train', **self.kwargs) - - storm_ids = [] - for item in self.dataset.collection: - storm_id = item['href'].split('/')[0].split('_')[-2] - storm_ids.append(storm_id) - + dataset = TropicalCyclone(split='train', **self.kwargs) train_indices, val_indices = group_shuffle_split( - storm_ids, test_size=0.2, random_state=0 + dataset.features['Storm ID'], test_size=0.2, random_state=0 ) - - self.train_dataset = Subset(self.dataset, train_indices) - self.val_dataset = Subset(self.dataset, val_indices) + self.train_dataset = Subset(dataset, train_indices) + self.val_dataset = Subset(dataset, val_indices) if stage in ['test']: self.test_dataset = TropicalCyclone(split='test', **self.kwargs) diff --git a/torchgeo/datasets/cyclone.py b/torchgeo/datasets/cyclone.py index eccca9d73..747463b69 100644 --- a/torchgeo/datasets/cyclone.py +++ b/torchgeo/datasets/cyclone.py @@ -3,7 +3,6 @@ """Tropical Cyclone Wind Estimation Competition dataset.""" -import json import os from collections.abc import Callable from functools import lru_cache @@ -11,6 +10,7 @@ from typing import Any import matplotlib.pyplot as plt import numpy as np +import pandas as pd import torch from matplotlib.figure import Figure from PIL import Image @@ -18,7 +18,7 @@ from torch import Tensor from .errors import DatasetNotFoundError from .geo import NonGeoDataset -from .utils import check_integrity, download_radiant_mlhub_collection, extract_archive +from .utils import which class TropicalCyclone(NonGeoDataset): @@ -26,10 +26,9 @@ class TropicalCyclone(NonGeoDataset): A collection of tropical storms in the Atlantic and East Pacific Oceans from 2000 to 2019 with corresponding maximum sustained surface wind speed. This dataset is split - into training and test categories for the purpose of a competition. - - See https://www.drivendata.org/competitions/72/predict-wind-speeds/ for more - information about the competition. + into training and test categories for the purpose of a competition. Read more about + the competition here: + https://www.drivendata.org/competitions/72/predict-wind-speeds/. If you use this dataset in your research, please cite the following paper: @@ -39,31 +38,17 @@ class TropicalCyclone(NonGeoDataset): This dataset requires the following additional library to be installed: - * `radiant-mlhub `_ to download the - imagery and labels from the Radiant Earth MLHub + * `azcopy `_: to download the + dataset from Source Cooperative. .. versionchanged:: 0.4 Class name changed from TropicalCycloneWindEstimation to TropicalCyclone to be consistent with TropicalCycloneDataModule. """ - collection_id = 'nasa_tropical_storm_competition' - collection_ids = [ - 'nasa_tropical_storm_competition_train_source', - 'nasa_tropical_storm_competition_test_source', - 'nasa_tropical_storm_competition_train_labels', - 'nasa_tropical_storm_competition_test_labels', - ] - md5s = { - 'train': { - 'source': '97e913667a398704ea8d28196d91dad6', - 'labels': '97d02608b74c82ffe7496a9404a30413', - }, - 'test': { - 'source': '8d88099e4b310feb7781d776a6e1dcef', - 'labels': 'd910c430f90153c1f78a99cbc08e7bd0', - }, - } + url = ( + 'https://radiantearth.blob.core.windows.net/mlhub/nasa-tropical-storm-challenge' + ) size = 366 def __init__( @@ -72,10 +57,8 @@ class TropicalCyclone(NonGeoDataset): split: str = 'train', transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None, download: bool = False, - api_key: str | None = None, - checksum: bool = False, ) -> None: - """Initialize a new Tropical Cyclone Wind Estimation Competition Dataset. + """Initialize a new TropicalCyclone instance. Args: root: root directory where dataset can be found @@ -83,30 +66,26 @@ class TropicalCyclone(NonGeoDataset): transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: AssertionError: if ``split`` argument is invalid DatasetNotFoundError: If dataset is not found and *download* is False. """ - assert split in self.md5s + assert split in {'train', 'test'} self.root = root self.split = split self.transforms = transforms - self.checksum = checksum + self.download = download - if download: - self._download(api_key) + self.filename = f'{split}_set' + if split == 'train': + self.filename = f'{split}ing_set' - if not self._check_integrity(): - raise DatasetNotFoundError(self) + self._verify() - output_dir = '_'.join([self.collection_id, split, 'source']) - filename = os.path.join(root, output_dir, 'collection.json') - with open(filename) as f: - self.collection = json.load(f)['links'] + self.features = pd.read_csv(os.path.join(root, f'{self.filename}_features.csv')) + self.labels = pd.read_csv(os.path.join(root, f'{self.filename}_labels.csv')) def __getitem__(self, index: int) -> dict[str, Any]: """Return an index within the dataset. @@ -117,15 +96,14 @@ class TropicalCyclone(NonGeoDataset): Returns: data, labels, field ids, and metadata at that index """ - source_id = os.path.split(self.collection[index]['href'])[0] - directory = os.path.join( - self.root, - '_'.join([self.collection_id, self.split, '{0}']), - source_id.replace('source', '{0}'), - ) + sample = { + 'relative_time': torch.tensor(self.features.iat[index, 2]), + 'ocean': torch.tensor(self.features.iat[index, 3]), + 'label': torch.tensor(self.labels.iat[index, 1]), + } - sample: dict[str, Any] = {'image': self._load_image(directory)} - sample.update(self._load_features(directory)) + image_id = self.labels.iat[index, 0] + sample['image'] = self._load_image(image_id) if self.transforms is not None: sample = self.transforms(sample) @@ -138,19 +116,19 @@ class TropicalCyclone(NonGeoDataset): Returns: length of the dataset """ - return len(self.collection) + return len(self.labels) @lru_cache - def _load_image(self, directory: str) -> Tensor: + def _load_image(self, image_id: str) -> Tensor: """Load a single image. Args: - directory: directory containing image + image_id: Filename of the image. Returns: the image """ - filename = os.path.join(directory.format('source'), 'image.jpg') + filename = os.path.join(self.root, self.split, f'{image_id}.jpg') with Image.open(filename) as img: if img.height != self.size or img.width != self.size: # Moved in PIL 9.1.0 @@ -164,61 +142,30 @@ class TropicalCyclone(NonGeoDataset): tensor = tensor.permute((2, 0, 1)).float() return tensor - def _load_features(self, directory: str) -> dict[str, Any]: - """Load features for a single image. - - Args: - directory: directory containing image - - Returns: - the features - """ - filename = os.path.join(directory.format('source'), 'features.json') - with open(filename) as f: - features: dict[str, Any] = json.load(f) - - filename = os.path.join(directory.format('labels'), 'labels.json') - with open(filename) as f: - features.update(json.load(f)) - - features['relative_time'] = int(features['relative_time']) - features['ocean'] = int(features['ocean']) - features['label'] = torch.tensor(int(features['wind_speed'])).float() - - return features - - def _check_integrity(self) -> bool: - """Check integrity of dataset. - - Returns: - True if dataset files are found and/or MD5s match, else False - """ - for split, resources in self.md5s.items(): - for resource_type, md5 in resources.items(): - filename = '_'.join([self.collection_id, split, resource_type]) - filename = os.path.join(self.root, filename + '.tar.gz') - if not check_integrity(filename, md5 if self.checksum else None): - return False - return True - - def _download(self, api_key: str | None = None) -> None: - """Download the dataset and extract it. - - Args: - api_key: a RadiantEarth MLHub API key to use for downloading the dataset - """ - if self._check_integrity(): - print('Files already downloaded and verified') + def _verify(self) -> None: + """Verify the integrity of the dataset.""" + # Check if the files already exist + files = [f'{self.filename}_features.csv', f'{self.filename}_labels.csv'] + exists = [os.path.exists(os.path.join(self.root, file)) for file in files] + if all(exists): return - for collection_id in self.collection_ids: - download_radiant_mlhub_collection(collection_id, self.root, api_key) + # Check if the user requested to download the dataset + if not self.download: + raise DatasetNotFoundError(self) - for split, resources in self.md5s.items(): - for resource_type in resources: - filename = '_'.join([self.collection_id, split, resource_type]) - filename = os.path.join(self.root, filename) + '.tar.gz' - extract_archive(filename, self.root) + # Download the dataset + self._download() + + def _download(self) -> None: + """Download the dataset.""" + directory = os.path.join(self.root, self.split) + os.makedirs(directory, exist_ok=True) + azcopy = which('azcopy') + azcopy('sync', f'{self.url}/{self.split}', directory, '--recursive=true') + files = [f'{self.filename}_features.csv', f'{self.filename}_labels.csv'] + for file in files: + azcopy('copy', f'{self.url}/{file}', self.root) def plot( self,