зеркало из https://github.com/microsoft/torchgeo.git
Updates to SustainBenchCropYield dataset (#1756)
* Updating download, fixing __getitem__ speed * Formatting * Update torchgeo/datasets/sustainbench_crop_yield.py
This commit is contained in:
Родитель
348ae7ccf2
Коммит
b2369d6bd4
|
@ -47,9 +47,9 @@ class SustainBenchCropYield(NonGeoDataset):
|
|||
|
||||
valid_countries = ["usa", "brazil", "argentina"]
|
||||
|
||||
md5 = "c2794e59512c897d9bea77b112848122"
|
||||
md5 = "362bad07b51a1264172b8376b39d1fc9"
|
||||
|
||||
url = "https://drive.google.com/file/d/1odwkI1hiE5rMZ4VfM0hOXzlFR4NbhrfU/view?usp=share_link" # noqa: E501
|
||||
url = "https://drive.google.com/file/d/1lhbmICpmNuOBlaErywgiD6i9nHuhuv0A/view?usp=drive_link" # noqa: E501
|
||||
|
||||
dir = "soybeans"
|
||||
|
||||
|
@ -96,7 +96,38 @@ class SustainBenchCropYield(NonGeoDataset):
|
|||
self.checksum = checksum
|
||||
|
||||
self._verify()
|
||||
self.collection = self.retrieve_collection()
|
||||
|
||||
self.images = []
|
||||
self.features = []
|
||||
|
||||
for country in self.countries:
|
||||
image_file_path = os.path.join(
|
||||
self.root, self.dir, country, f"{self.split}_hists.npz"
|
||||
)
|
||||
target_file_path = image_file_path.replace("_hists", "_yields")
|
||||
years_file_path = image_file_path.replace("_hists", "_years")
|
||||
ndvi_file_path = image_file_path.replace("_hists", "_ndvi")
|
||||
|
||||
npz_file = np.load(image_file_path)["data"]
|
||||
target_npz_file = np.load(target_file_path)["data"]
|
||||
year_npz_file = np.load(years_file_path)["data"]
|
||||
ndvi_npz_file = np.load(ndvi_file_path)["data"]
|
||||
num_data_points = npz_file.shape[0]
|
||||
for idx in range(num_data_points):
|
||||
sample = npz_file[idx]
|
||||
sample = torch.from_numpy(sample).permute(2, 0, 1).to(torch.float32)
|
||||
self.images.append(sample)
|
||||
|
||||
target = target_npz_file[idx]
|
||||
year = year_npz_file[idx]
|
||||
ndvi = ndvi_npz_file[idx]
|
||||
|
||||
features = {
|
||||
"label": torch.tensor(target).to(torch.float32),
|
||||
"year": torch.tensor(int(year)),
|
||||
"ndvi": torch.from_numpy(ndvi).to(dtype=torch.float32),
|
||||
}
|
||||
self.features.append(features)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of data points in the dataset.
|
||||
|
@ -104,7 +135,7 @@ class SustainBenchCropYield(NonGeoDataset):
|
|||
Returns:
|
||||
length of the dataset
|
||||
"""
|
||||
return len(self.collection)
|
||||
return len(self.images)
|
||||
|
||||
def __getitem__(self, index: int) -> dict[str, Tensor]:
|
||||
"""Return an index within the dataset.
|
||||
|
@ -115,76 +146,14 @@ class SustainBenchCropYield(NonGeoDataset):
|
|||
Returns:
|
||||
data and label at that index
|
||||
"""
|
||||
input_file_path, sample_idx = self.collection[index]
|
||||
|
||||
sample: dict[str, Tensor] = {
|
||||
"image": self._load_image(input_file_path, sample_idx)
|
||||
}
|
||||
sample.update(self._load_features(input_file_path, sample_idx))
|
||||
sample: dict[str, Tensor] = {"image": self.images[index]}
|
||||
sample.update(self.features[index])
|
||||
|
||||
if self.transforms is not None:
|
||||
sample = self.transforms(sample)
|
||||
|
||||
return sample
|
||||
|
||||
def _load_image(self, path: str, sample_idx: int) -> Tensor:
|
||||
"""Load input image.
|
||||
|
||||
Args:
|
||||
path: path to input npz collection
|
||||
sample_idx: what sample to index from the npz collection
|
||||
|
||||
Returns:
|
||||
input image as tensor
|
||||
"""
|
||||
arr = np.load(path)["data"][sample_idx]
|
||||
# return [channel, height, width]
|
||||
return torch.from_numpy(arr).permute(2, 0, 1).to(torch.float32)
|
||||
|
||||
def _load_features(self, path: str, sample_idx: int) -> dict[str, Tensor]:
|
||||
"""Load features value.
|
||||
|
||||
Args:
|
||||
path: path to image npz collection
|
||||
sample_idx: what sample to index from the npz collection
|
||||
|
||||
Returns:
|
||||
target regression value
|
||||
"""
|
||||
target_file_path = path.replace("_hists", "_yields")
|
||||
target = np.load(target_file_path)["data"][sample_idx]
|
||||
|
||||
years_file_path = path.replace("_hists", "_years")
|
||||
year = int(np.load(years_file_path)["data"][sample_idx])
|
||||
|
||||
ndvi_file_path = path.replace("_hists", "_ndvi")
|
||||
ndvi = np.load(ndvi_file_path)["data"][sample_idx]
|
||||
|
||||
features = {
|
||||
"label": torch.tensor(target).to(torch.float32),
|
||||
"year": torch.tensor(year),
|
||||
"ndvi": torch.from_numpy(ndvi).to(dtype=torch.float32),
|
||||
}
|
||||
return features
|
||||
|
||||
def retrieve_collection(self) -> list[tuple[str, int]]:
    """Retrieve the collection.

    For every country in ``self.countries`` the ``{split}_hists.npz``
    archive is opened and one ``(file_path, index)`` entry is produced per
    element along the first axis of its ``"data"`` array.

    Returns:
        path and index to dataset samples
    """
    collection: list[tuple[str, int]] = []
    for country in self.countries:
        file_path = os.path.join(
            self.root, self.dir, country, f"{self.split}_hists.npz"
        )
        # Close the NpzFile once the sample count is known so that no file
        # handle stays open per country.
        with np.load(file_path) as npz_file:
            num_data_points = npz_file["data"].shape[0]
        collection.extend((file_path, idx) for idx in range(num_data_points))

    return collection
|
||||
|
||||
def _verify(self) -> None:
|
||||
"""Verify the integrity of the dataset.
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче