Updates to SustainBenchCropYield dataset (#1756)

* Update download URL and md5 checksum; speed up __getitem__ by caching samples in memory

* Formatting

* Update torchgeo/datasets/sustainbench_crop_yield.py
Author: Caleb Robinson 2023-12-07 07:27:09 -08:00, committed by isaaccorley
Parent: 348ae7ccf2
Commit: b2369d6bd4
1 changed file with 37 additions and 68 deletions
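
In the old implementation, every __getitem__ call re-opened and re-parsed the per-country .npz archives; this commit instead loads all images and features into memory once in __init__, so indexing becomes a plain list lookup. A minimal standalone sketch of the pattern, assuming a single .npz file containing a "data" array (the EagerSamples name and file layout are illustrative, not the dataset's actual API):

import numpy as np
import torch

# Before: the whole archive is re-read and re-parsed on every access.
def load_sample_lazy(path: str, idx: int) -> torch.Tensor:
    arr = np.load(path)["data"][idx]
    return torch.from_numpy(arr).permute(2, 0, 1).to(torch.float32)

# After: parse the archive once, then serve pre-built tensors.
class EagerSamples:
    def __init__(self, path: str) -> None:
        data = np.load(path)["data"]  # single read at construction time
        self.images = [
            torch.from_numpy(arr).permute(2, 0, 1).to(torch.float32)
            for arr in data
        ]

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.images[idx]  # pure in-memory lookup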

@@ -47,9 +47,9 @@ class SustainBenchCropYield(NonGeoDataset):

     valid_countries = ["usa", "brazil", "argentina"]

-    md5 = "c2794e59512c897d9bea77b112848122"
+    md5 = "362bad07b51a1264172b8376b39d1fc9"

-    url = "https://drive.google.com/file/d/1odwkI1hiE5rMZ4VfM0hOXzlFR4NbhrfU/view?usp=share_link"  # noqa: E501
+    url = "https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link"  # noqa: E501

     dir = "soybeans"
@@ -96,7 +96,38 @@ class SustainBenchCropYield(NonGeoDataset):
         self.checksum = checksum

         self._verify()

-        self.collection = self.retrieve_collection()
+        self.images = []
+        self.features = []
+
+        for country in self.countries:
+            image_file_path = os.path.join(
+                self.root, self.dir, country, f"{self.split}_hists.npz"
+            )
+            target_file_path = image_file_path.replace("_hists", "_yields")
+            years_file_path = image_file_path.replace("_hists", "_years")
+            ndvi_file_path = image_file_path.replace("_hists", "_ndvi")
+
+            npz_file = np.load(image_file_path)["data"]
+            target_npz_file = np.load(target_file_path)["data"]
+            year_npz_file = np.load(years_file_path)["data"]
+            ndvi_npz_file = np.load(ndvi_file_path)["data"]
+
+            num_data_points = npz_file.shape[0]
+            for idx in range(num_data_points):
+                sample = npz_file[idx]
+                sample = torch.from_numpy(sample).permute(2, 0, 1).to(torch.float32)
+                self.images.append(sample)
+
+                target = target_npz_file[idx]
+                year = year_npz_file[idx]
+                ndvi = ndvi_npz_file[idx]
+                features = {
+                    "label": torch.tensor(target).to(torch.float32),
+                    "year": torch.tensor(int(year)),
+                    "ndvi": torch.from_numpy(ndvi).to(dtype=torch.float32),
+                }
+                self.features.append(features)

     def __len__(self) -> int:
         """Return the number of data points in the dataset.
@@ -104,7 +135,7 @@ class SustainBenchCropYield(NonGeoDataset):
         Returns:
             length of the dataset
         """
-        return len(self.collection)
+        return len(self.images)

     def __getitem__(self, index: int) -> dict[str, Tensor]:
         """Return an index within the dataset.
@@ -115,76 +146,14 @@ class SustainBenchCropYield(NonGeoDataset):
         Returns:
             data and label at that index
         """
-        input_file_path, sample_idx = self.collection[index]
-
-        sample: dict[str, Tensor] = {
-            "image": self._load_image(input_file_path, sample_idx)
-        }
-        sample.update(self._load_features(input_file_path, sample_idx))
+        sample: dict[str, Tensor] = {"image": self.images[index]}
+        sample.update(self.features[index])

         if self.transforms is not None:
             sample = self.transforms(sample)

         return sample

-    def _load_image(self, path: str, sample_idx: int) -> Tensor:
-        """Load input image.
-
-        Args:
-            path: path to input npz collection
-            sample_idx: what sample to index from the npz collection
-
-        Returns:
-            input image as tensor
-        """
-        arr = np.load(path)["data"][sample_idx]
-        # return [channel, height, width]
-        return torch.from_numpy(arr).permute(2, 0, 1).to(torch.float32)
-
-    def _load_features(self, path: str, sample_idx: int) -> dict[str, Tensor]:
-        """Load features value.
-
-        Args:
-            path: path to image npz collection
-            sample_idx: what sample to index from the npz collection
-
-        Returns:
-            target regression value
-        """
-        target_file_path = path.replace("_hists", "_yields")
-        target = np.load(target_file_path)["data"][sample_idx]
-
-        years_file_path = path.replace("_hists", "_years")
-        year = int(np.load(years_file_path)["data"][sample_idx])
-
-        ndvi_file_path = path.replace("_hists", "_ndvi")
-        ndvi = np.load(ndvi_file_path)["data"][sample_idx]
-
-        features = {
-            "label": torch.tensor(target).to(torch.float32),
-            "year": torch.tensor(year),
-            "ndvi": torch.from_numpy(ndvi).to(dtype=torch.float32),
-        }
-        return features
-
-    def retrieve_collection(self) -> list[tuple[str, int]]:
-        """Retrieve the collection.
-
-        Returns:
-            path and index to dataset samples
-        """
-        collection = []
-        for country in self.countries:
-            file_path = os.path.join(
-                self.root, self.dir, country, f"{self.split}_hists.npz"
-            )
-            npz_file = np.load(file_path)
-            num_data_points = npz_file["data"].shape[0]
-            for idx in range(num_data_points):
-                collection.append((file_path, idx))
-        return collection
-
     def _verify(self) -> None:
         """Verify the integrity of the dataset.