зеркало из https://github.com/microsoft/torchgeo.git
Add CropHarvest Dataset (#1677)
* initial commit * Added functionality to cropharvest dataset * Added test coverage * test fixes * mdpy typing * flake8 revision * added docs * fixed h5py import * fix .rst underline * updated tests to mock h5py module * fixed documentation * fixed black formating * turn labels to tensors * fix data generationa and mdpy for tensor encoding * update verify model * doc style * test coverage * fix test coverage leaks * Update torchgeo/datasets/cropharvest.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/cropharvest.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/cropharvest.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/cropharvest.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * update test data path and monkeypatch * Update torchgeo/datasets/cropharvest.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/cropharvest.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * remove hard coded classes * fixed plot and label one hot encoding * refactor datasetnotfounderror * resolve conflict with main * refactored importerror * mdpy * Update torchgeo/datasets/cropharvest.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/cropharvest.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/cropharvest.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/cropharvest.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/cropharvest.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/cropharvest.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/cropharvest.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Update torchgeo/datasets/cropharvest.py Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> * Remove empty class and correct csv * Update cropharvest.py * formatting changes --------- Co-authored-by: georgehuber <“georgehuber8@gmail.com”> Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com> Co-authored-by: Caleb Robinson <calebrob6@gmail.com>
This commit is contained in:
Родитель
1c20678a10
Коммит
43d7133614
|
@ -206,6 +206,11 @@ COWC
|
|||
.. autoclass:: COWCCounting
|
||||
.. autoclass:: COWCDetection
|
||||
|
||||
CropHarvest
|
||||
^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: CropHarvest
|
||||
|
||||
Kenya Crop Type
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
|
|||
`ChaBuD`_,CD,Sentinel-2,"OpenRAIL",356,2,512x512,10,MSI
|
||||
`Cloud Cover Detection`_,S,Sentinel-2,"CC-BY-4.0","22,728",2,512x512,10,MSI
|
||||
`COWC`_,"C, R","CSUAV AFRL, ISPRS, LINZ, AGRC","AGPL-3.0-only","388,435",2,256x256,0.15,RGB
|
||||
`CropHarvest`_,"C","Sentinel-1/2, SRTM, ERA5","CC-BY-SA-4.0","70,213",351,1x1,10,"SAR, MSI, SRTM"
|
||||
`Kenya Crop Type`_,S,Sentinel-2,"CC-BY-SA-4.0","4,688",7,"3,035x2,016",10,MSI
|
||||
`DeepGlobe Land Cover`_,S,DigitalGlobe +Vivid,-,803,7,"2,448x2,448",0.5,RGB
|
||||
`DFC2022`_,S,Aerial,"CC-BY-4.0","3,981",15,"2,000x2,000",0.5,RGB
|
||||
|
@ -49,4 +50,4 @@ Dataset,Task,Source,License,# Samples,# Classes,Size (px),Resolution (m),Bands
|
|||
`VHR-10`_,I,"Google Earth, Vaihingen","MIT",800,10,"358--1,728",0.08--2,RGB
|
||||
`Western USA Live Fuel Moisture`_,R,"Landsat8, Sentinel-1","CC-BY-NC-ND-4.0",2615,-,-,-,-
|
||||
`xView2`_,CD,Maxar,"CC-BY-NC-SA-4.0","3,732",4,"1,024x1,024",0.8,RGB
|
||||
`ZueriCrop`_,"I, T",Sentinel-2,-,116K,48,24x24,10,MSI
|
||||
`ZueriCrop`_,"I, T",Sentinel-2,-,116K,48,24x24,10,MSI
|
||||
|
|
|
|
@ -0,0 +1,153 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
||||
SIZE = 32
|
||||
|
||||
np.random.seed(0)
|
||||
|
||||
PATHS = [
|
||||
os.path.join("cropharvest", "features", "arrays", "0_TestDataset1.h5"),
|
||||
os.path.join("cropharvest", "features", "arrays", "1_TestDataset1.h5"),
|
||||
os.path.join("cropharvest", "features", "arrays", "2_TestDataset1.h5"),
|
||||
os.path.join("cropharvest", "features", "arrays", "0_TestDataset2.h5"),
|
||||
os.path.join("cropharvest", "features", "arrays", "1_TestDataset2.h5"),
|
||||
]
|
||||
|
||||
|
||||
def create_geojson():
|
||||
geojson = {
|
||||
"type": "FeatureCollection",
|
||||
"crs": {},
|
||||
"features": [
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"dataset": "TestDataset1",
|
||||
"index": 0,
|
||||
"is_crop": 1,
|
||||
"label": "soybean",
|
||||
},
|
||||
"geometry": {
|
||||
"type": "Polygon",
|
||||
"coordinates": [
|
||||
[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"dataset": "TestDataset1",
|
||||
"index": 0,
|
||||
"is_crop": 1,
|
||||
"label": "alfalfa",
|
||||
},
|
||||
"geometry": {
|
||||
"type": "Polygon",
|
||||
"coordinates": [
|
||||
[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"dataset": "TestDataset1",
|
||||
"index": 1,
|
||||
"is_crop": 1,
|
||||
"label": None,
|
||||
},
|
||||
"geometry": {
|
||||
"type": "Polygon",
|
||||
"coordinates": [
|
||||
[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"dataset": "TestDataset2",
|
||||
"index": 2,
|
||||
"is_crop": 1,
|
||||
"label": "maize",
|
||||
},
|
||||
"geometry": {
|
||||
"type": "Polygon",
|
||||
"coordinates": [
|
||||
[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]
|
||||
],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"dataset": "TestDataset2",
|
||||
"index": 1,
|
||||
"is_crop": 0,
|
||||
"label": None,
|
||||
},
|
||||
"geometry": {
|
||||
"type": "Polygon",
|
||||
"coordinates": [
|
||||
[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
return geojson
|
||||
|
||||
|
||||
def create_file(path: str) -> None:
|
||||
Z = np.random.randint(4000, size=(12, 18), dtype=np.int64)
|
||||
with h5py.File(path, "w") as f:
|
||||
f.create_dataset("array", data=Z)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
directory = "cropharvest"
|
||||
|
||||
# remove old data
|
||||
to_remove = [
|
||||
os.path.join(directory, "features"),
|
||||
os.path.join(directory, "features.tar.gz"),
|
||||
os.path.join(directory, "labels.geojson"),
|
||||
]
|
||||
for path in to_remove:
|
||||
if os.path.isdir(path):
|
||||
shutil.rmtree(path)
|
||||
|
||||
label_path = os.path.join(directory, "labels.geojson")
|
||||
geojson = create_geojson()
|
||||
os.makedirs(os.path.dirname(label_path), exist_ok=True)
|
||||
|
||||
with open(label_path, "w") as f:
|
||||
json.dump(geojson, f)
|
||||
|
||||
for path in PATHS:
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
create_file(path)
|
||||
|
||||
# compress data
|
||||
source_dir = os.path.join(directory, "features")
|
||||
shutil.make_archive(source_dir, "gztar", directory, "features")
|
||||
|
||||
# compute checksum
|
||||
with open(label_path, "rb") as f:
|
||||
md5 = hashlib.md5(f.read()).hexdigest()
|
||||
print(f"{label_path}: {md5}")
|
||||
|
||||
with open(os.path.join(directory, "features.tar.gz"), "rb") as f:
|
||||
md5 = hashlib.md5(f.read()).hexdigest()
|
||||
print(f"zipped features: {md5}")
|
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -0,0 +1 @@
|
|||
{"type": "FeatureCollection", "crs": {}, "features": [{"type": "Feature", "properties": {"dataset": "TestDataset1", "index": 0, "is_crop": 1, "label": "soybean"}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {"dataset": "TestDataset1", "index": 0, "is_crop": 1, "label": "alfalfa"}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {"dataset": "TestDataset1", "index": 1, "is_crop": 1, "label": null}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {"dataset": "TestDataset2", "index": 2, "is_crop": 1, "label": "maize"}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}, {"type": "Feature", "properties": {"dataset": "TestDataset2", "index": 1, "is_crop": 0, "label": null}, "geometry": {"type": "Polygon", "coordinates": [[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0], [0.0, 0.0]]]}}]}
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import builtins
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pytest import MonkeyPatch
|
||||
|
||||
import torchgeo.datasets.utils
|
||||
from torchgeo.datasets import CropHarvest, DatasetNotFoundError
|
||||
|
||||
pytest.importorskip("h5py", minversion="3")
|
||||
|
||||
|
||||
def download_url(url: str, root: str, filename: str, md5: str) -> None:
|
||||
shutil.copy(url, os.path.join(root, filename))
|
||||
|
||||
|
||||
class TestCropHarvest:
|
||||
@pytest.fixture
|
||||
def mock_missing_module(self, monkeypatch: MonkeyPatch) -> None:
|
||||
import_orig = builtins.__import__
|
||||
|
||||
def mocked_import(name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
if name == "h5py":
|
||||
raise ImportError()
|
||||
return import_orig(name, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(builtins, "__import__", mocked_import)
|
||||
|
||||
@pytest.fixture
|
||||
def dataset(self, monkeypatch: MonkeyPatch, tmp_path: Path) -> CropHarvest:
|
||||
monkeypatch.setattr(torchgeo.datasets.cropharvest, "download_url", download_url)
|
||||
monkeypatch.setitem(
|
||||
CropHarvest.file_dict["features"], "md5", "ef6f4f00c0b3b50ed8380b0044928572"
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
CropHarvest.file_dict["labels"], "md5", "1d93b6bfcec7b6797b75acbd9d284b92"
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
CropHarvest.file_dict["features"],
|
||||
"url",
|
||||
os.path.join("tests", "data", "cropharvest", "features.tar.gz"),
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
CropHarvest.file_dict["labels"],
|
||||
"url",
|
||||
os.path.join("tests", "data", "cropharvest", "labels.geojson"),
|
||||
)
|
||||
|
||||
root = str(tmp_path)
|
||||
transforms = nn.Identity()
|
||||
|
||||
dataset = CropHarvest(root, transforms, download=True, checksum=True)
|
||||
return dataset
|
||||
|
||||
def test_getitem(self, dataset: CropHarvest) -> None:
|
||||
x = dataset[0]
|
||||
assert isinstance(x, dict)
|
||||
assert isinstance(x["array"], torch.Tensor)
|
||||
assert isinstance(x["label"], torch.Tensor)
|
||||
assert x["array"].shape == (12, 18)
|
||||
y = dataset[2]
|
||||
assert y["label"] == 1
|
||||
|
||||
def test_len(self, dataset: CropHarvest) -> None:
|
||||
assert len(dataset) == 5
|
||||
|
||||
def test_already_downloaded(self, dataset: CropHarvest, tmp_path: Path) -> None:
|
||||
CropHarvest(root=str(tmp_path), download=False)
|
||||
|
||||
def test_downloaded_zipped(self, dataset: CropHarvest, tmp_path: Path) -> None:
|
||||
feature_path = os.path.join(tmp_path, "features")
|
||||
shutil.rmtree(feature_path)
|
||||
CropHarvest(root=str(tmp_path), download=True)
|
||||
|
||||
def test_not_downloaded(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(DatasetNotFoundError, match="Dataset not found"):
|
||||
CropHarvest(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: CropHarvest) -> None:
|
||||
x = dataset[0].copy()
|
||||
dataset.plot(x, subtitle="Test")
|
||||
plt.close()
|
||||
|
||||
def test_mock_missing_module(
|
||||
self, dataset: CropHarvest, tmp_path: Path, mock_missing_module: None
|
||||
) -> None:
|
||||
with pytest.raises(
|
||||
ImportError,
|
||||
match="h5py is not installed and is required to use this dataset",
|
||||
):
|
||||
CropHarvest(root=str(tmp_path), download=True)[0]
|
|
@ -28,6 +28,7 @@ from .chesapeake import (
|
|||
from .cloud_cover import CloudCoverDetection
|
||||
from .cms_mangrove_canopy import CMSGlobalMangroveCanopy
|
||||
from .cowc import COWC, COWCCounting, COWCDetection
|
||||
from .cropharvest import CropHarvest
|
||||
from .cv4a_kenya_crop_type import CV4AKenyaCropType
|
||||
from .cyclone import TropicalCyclone
|
||||
from .deepglobelandcover import DeepGlobeLandCover
|
||||
|
@ -150,6 +151,7 @@ __all__ = (
|
|||
"ChesapeakeWV",
|
||||
"ChesapeakeCVPR",
|
||||
"CMSGlobalMangroveCanopy",
|
||||
"CropHarvest",
|
||||
"EDDMapS",
|
||||
"Esri2020",
|
||||
"EUDEM",
|
||||
|
|
|
@ -0,0 +1,318 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""CropHarvest datasets."""
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
from .utils import DatasetNotFoundError, download_url, extract_archive
|
||||
|
||||
|
||||
class CropHarvest(NonGeoDataset):
|
||||
"""CropHarvest dataset.
|
||||
|
||||
`CropHarvest <https://github.com/nasaharvest/cropharvest>`__ is a
|
||||
crop classification dataset.
|
||||
|
||||
Dataset features:
|
||||
|
||||
* single pixel time series with crop-type labels
|
||||
* 18 bands per image over 12 months
|
||||
|
||||
Dataset format:
|
||||
|
||||
* arrays are 12x18 with 18 bands over 12 months
|
||||
|
||||
Dataset properties:
|
||||
|
||||
1. is_crop - whether or not a single pixel contains cropland
|
||||
2. classification_label - optional field identifying a specific crop type
|
||||
3. dataset - source dataset for the imagery
|
||||
4. lat - latitude
|
||||
5. lon - longitude
|
||||
|
||||
If you use this dataset in your research, please cite the following paper:
|
||||
|
||||
* https://openreview.net/forum?id=JtjzUXPEaCu
|
||||
|
||||
This dataset requires the following additional library to be installed:
|
||||
|
||||
* `h5py <https://pypi.org/project/h5py/>`_ to load the dataset
|
||||
|
||||
.. versionadded:: 0.6
|
||||
"""
|
||||
|
||||
# https://github.com/nasaharvest/cropharvest/blob/main/cropharvest/bands.py
|
||||
all_bands = [
|
||||
"VV",
|
||||
"VH",
|
||||
"B2",
|
||||
"B3",
|
||||
"B4",
|
||||
"B5",
|
||||
"B6",
|
||||
"B7",
|
||||
"B8",
|
||||
"B8A",
|
||||
"B9",
|
||||
"B11",
|
||||
"B12",
|
||||
"temperature_2m",
|
||||
"total_precipitation",
|
||||
"elevation",
|
||||
"slope",
|
||||
"NDVI",
|
||||
]
|
||||
rgb_bands = ["B4", "B3", "B2"]
|
||||
|
||||
features_url = "https://zenodo.org/records/7257688/files/features.tar.gz?download=1"
|
||||
labels_url = "https://zenodo.org/records/7257688/files/labels.geojson?download=1"
|
||||
file_dict = {
|
||||
"features": {
|
||||
"url": features_url,
|
||||
"filename": "features.tar.gz",
|
||||
"extracted_filename": os.path.join("features", "arrays"),
|
||||
"md5": "cad4df655c75caac805a80435e46ee3e",
|
||||
},
|
||||
"labels": {
|
||||
"url": labels_url,
|
||||
"filename": "labels.geojson",
|
||||
"extracted_filename": "labels.geojson",
|
||||
"md5": "bf7bae6812fc7213481aff6a2e34517d",
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root: str = "data",
|
||||
transforms: Optional[Callable[[dict[str, Tensor]], dict[str, Tensor]]] = None,
|
||||
download: bool = False,
|
||||
checksum: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a new CropHarvest dataset instance.
|
||||
|
||||
Args:
|
||||
root: root directory where dataset can be found
|
||||
transforms: a function/transform that takes input sample and its target as
|
||||
entry and returns a transformed version
|
||||
download: if True, download dataset and store it in the root directory
|
||||
checksum: if True, check the MD5 of the downloaded files (may be slow)
|
||||
|
||||
Raises:
|
||||
DatasetNotFoundError: If dataset is not found and *download* is False.
|
||||
ImportError: If h5py is not installed
|
||||
"""
|
||||
try:
|
||||
import h5py # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"h5py is not installed and is required to use this dataset"
|
||||
)
|
||||
|
||||
self.root = root
|
||||
self.transforms = transforms
|
||||
self.checksum = checksum
|
||||
self.download = download
|
||||
|
||||
self._verify()
|
||||
|
||||
self.files = self._load_features(self.root)
|
||||
self.labels = self._load_labels(self.root)
|
||||
self.classes = self.labels["properties.label"].unique()
|
||||
self.classes = self.classes[self.classes != np.array(None)]
|
||||
self.classes = np.insert(self.classes, 0, ["None", "Other"])
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, Tensor]:
|
||||
"""Return an index within the dataset.
|
||||
|
||||
Args:
|
||||
index: index to return
|
||||
|
||||
Returns:
|
||||
single pixel time-series array and label at that index
|
||||
"""
|
||||
files = self.files[index]
|
||||
data = self._load_array(files["chip"])
|
||||
|
||||
label = self._load_label(files["index"], files["dataset"])
|
||||
sample = {"array": data, "label": label}
|
||||
|
||||
if self.transforms is not None:
|
||||
sample = self.transforms(sample)
|
||||
|
||||
return sample
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of data points in the dataset.
|
||||
|
||||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
return len(self.files)
|
||||
|
||||
def _load_features(self, root: str) -> list[dict[str, str]]:
|
||||
"""Return the paths of the files in the dataset.
|
||||
|
||||
Args:
|
||||
root: root dir of dataset
|
||||
|
||||
Returns:
|
||||
list of dicts containing path for each of hd5 single pixel time series and
|
||||
its key for associated data
|
||||
"""
|
||||
files = []
|
||||
chips = glob.glob(
|
||||
os.path.join(root, self.file_dict["features"]["extracted_filename"], "*.h5")
|
||||
)
|
||||
chips = sorted(os.path.basename(chip) for chip in chips)
|
||||
for chip in chips:
|
||||
chip_path = os.path.join(
|
||||
root, self.file_dict["features"]["extracted_filename"], chip
|
||||
)
|
||||
index = chip.split("_")[0]
|
||||
dataset = chip.split("_")[1][:-3]
|
||||
files.append(dict(chip=chip_path, index=index, dataset=dataset))
|
||||
return files
|
||||
|
||||
def _load_labels(self, root: str) -> pd.DataFrame:
|
||||
"""Return the paths of the files in the dataset.
|
||||
|
||||
Args:
|
||||
root: root dir of dataset
|
||||
|
||||
Returns:
|
||||
pandas dataframe containing label data for each feature
|
||||
"""
|
||||
filename = self.file_dict["labels"]["extracted_filename"]
|
||||
with open(os.path.join(root, filename), encoding="utf8") as f:
|
||||
data = json.load(f)
|
||||
df = pd.json_normalize(data["features"])
|
||||
return df
|
||||
|
||||
def _load_array(self, path: str) -> Tensor:
|
||||
"""Load an individual single pixel time series.
|
||||
|
||||
Args:
|
||||
path: path to the image
|
||||
|
||||
Returns:
|
||||
the image
|
||||
"""
|
||||
import h5py
|
||||
|
||||
filename = os.path.join(path)
|
||||
with h5py.File(filename, "r") as f:
|
||||
array = f.get("array")[()]
|
||||
tensor = torch.from_numpy(array)
|
||||
return tensor
|
||||
|
||||
def _load_label(self, idx: str, dataset: str) -> Tensor:
|
||||
"""Load the crop-type label for a single pixel time series.
|
||||
|
||||
Args:
|
||||
idx: sample index in labels.geojson
|
||||
dataset: dataset name to query labels.geojson
|
||||
|
||||
Returns:
|
||||
the crop-type label
|
||||
"""
|
||||
index = int(idx)
|
||||
row = self.labels[
|
||||
(self.labels["properties.index"] == index)
|
||||
& (self.labels["properties.dataset"] == dataset)
|
||||
]
|
||||
row = row.to_dict(orient="records")[0]
|
||||
label = "None"
|
||||
if row["properties.label"]:
|
||||
label = row["properties.label"]
|
||||
elif row["properties.is_crop"] == 1:
|
||||
label = "Other"
|
||||
|
||||
return torch.tensor(np.where(self.classes == label)[0][0])
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset."""
|
||||
# Check if feature files already exist
|
||||
feature_path = os.path.join(
|
||||
self.root, self.file_dict["features"]["extracted_filename"]
|
||||
)
|
||||
feature_path_zip = os.path.join(
|
||||
self.root, self.file_dict["features"]["filename"]
|
||||
)
|
||||
label_path = os.path.join(
|
||||
self.root, self.file_dict["labels"]["extracted_filename"]
|
||||
)
|
||||
# Check if labels exist
|
||||
if os.path.exists(label_path):
|
||||
# Check if features exist
|
||||
if os.path.exists(feature_path):
|
||||
return
|
||||
# Check if features are downloaded in zip format
|
||||
if os.path.exists(feature_path_zip):
|
||||
self._extract()
|
||||
return
|
||||
|
||||
# Check if the user requested to download the dataset
|
||||
if not self.download:
|
||||
raise DatasetNotFoundError(self)
|
||||
|
||||
# Download and extract the dataset
|
||||
self._download()
|
||||
self._extract()
|
||||
|
||||
def _download(self) -> None:
|
||||
"""Download the dataset and extract it."""
|
||||
features_path = os.path.join(self.file_dict["features"]["filename"])
|
||||
download_url(
|
||||
self.file_dict["features"]["url"],
|
||||
self.root,
|
||||
filename=features_path,
|
||||
md5=self.file_dict["features"]["md5"] if self.checksum else None,
|
||||
)
|
||||
|
||||
download_url(
|
||||
self.file_dict["labels"]["url"],
|
||||
self.root,
|
||||
filename=os.path.join(self.file_dict["labels"]["filename"]),
|
||||
md5=self.file_dict["labels"]["md5"] if self.checksum else None,
|
||||
)
|
||||
|
||||
def _extract(self) -> None:
|
||||
"""Extract the dataset."""
|
||||
features_path = os.path.join(self.root, self.file_dict["features"]["filename"])
|
||||
extract_archive(features_path)
|
||||
|
||||
def plot(self, sample: dict[str, Tensor], subtitle: Optional[str] = None) -> Figure:
|
||||
"""Plot a sample from the dataset using bands for Agriculture RGB composite.
|
||||
|
||||
Args:
|
||||
sample: a sample returned by :meth:`__getitem__`
|
||||
suptitle: optional subtitle to use for figure
|
||||
|
||||
Returns:
|
||||
a matplotlib Figure with the rendered sample
|
||||
"""
|
||||
fig, axs = plt.subplots()
|
||||
bands = [self.all_bands.index(band) for band in self.rgb_bands]
|
||||
rgb = np.array(sample["array"])[:, bands] / 3000
|
||||
axs.imshow(rgb[None, ...])
|
||||
axs.set_title(f'Crop type: {self.classes[sample["label"]]}')
|
||||
axs.set_xticks(np.arange(12))
|
||||
axs.set_xticklabels(np.arange(12) + 1)
|
||||
axs.set_yticks([])
|
||||
axs.set_xlabel("Month")
|
||||
if subtitle is not None:
|
||||
plt.suptitle(subtitle)
|
||||
|
||||
return fig
|
Загрузка…
Ссылка в новой задаче