Rwanda Field Boundary: don't plot empty masks (#2254)

* Rwanda Field Boundary: don't plot empty masks

* Import sorting
This commit is contained in:
Adam J. Stewart 2024-08-27 16:39:11 +02:00 коммит произвёл GitHub
Родитель 451b5a5919
Коммит 0f37e8b85e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
1 изменённых файлов: 24 добавлений и 31 удалений

Просмотреть файл

@ -13,6 +13,7 @@ import numpy as np
import rasterio
import rasterio.features
import torch
from einops import rearrange
from matplotlib.figure import Figure
from torch import Tensor
@ -181,41 +182,33 @@ class RwandaFieldBoundary(NonGeoDataset):
else:
raise RGBBandsMissingError()
num_time_points = sample['image'].shape[0]
assert time_step < num_time_points
ncols = 1
for key in ('mask', 'prediction'):
if key in sample:
ncols += 1
image = np.rollaxis(sample['image'][time_step, rgb_indices].numpy(), 0, 3)
image = np.clip(image / 2000, 0, 1)
fig, axs = plt.subplots(ncols=ncols, squeeze=False)
image = torch.clamp(sample['image'][time_step, rgb_indices] / 2000, 0, 1)
image = rearrange(image, 'c h w -> h w c')
axs[0, 0].imshow(image)
axs[0, 0].axis('off')
if show_titles:
axs[0, 0].set_title(f't={time_step}')
if 'mask' in sample:
mask = sample['mask'].numpy()
else:
mask = np.zeros_like(image)
num_panels = 2
showing_predictions = 'prediction' in sample
if showing_predictions:
predictions = sample['prediction'].numpy()
num_panels += 1
fig, axs = plt.subplots(ncols=num_panels, figsize=(4 * num_panels, 4))
axs[0].imshow(image)
axs[0].axis('off')
if show_titles:
axs[0].set_title(f't={time_step}')
axs[1].imshow(mask, vmin=0, vmax=1, interpolation='none')
axs[1].axis('off')
if show_titles:
axs[1].set_title('Mask')
if showing_predictions:
axs[2].imshow(predictions, vmin=0, vmax=1, interpolation='none')
axs[2].axis('off')
axs[0, 1].imshow(sample['mask'])
axs[0, 1].axis('off')
if show_titles:
axs[2].set_title('Predictions')
axs[0, 1].set_title('Mask')
if 'prediction' in sample:
axs[0, 2].imshow(sample['prediction'])
axs[0, 2].axis('off')
if show_titles:
axs[0, 2].set_title('Prediction')
if suptitle is not None:
plt.suptitle(suptitle)
fig.suptitle(suptitle)
return fig