зеркало из https://github.com/microsoft/torchgeo.git
Rwanda Field Boundary: don't plot empty masks (#2254)
* Rwanda Field Boundary: don't plot empty masks * Import sorting
This commit is contained in:
Родитель
451b5a5919
Коммит
0f37e8b85e
|
@ -13,6 +13,7 @@ import numpy as np
|
|||
import rasterio
|
||||
import rasterio.features
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -181,41 +182,33 @@ class RwandaFieldBoundary(NonGeoDataset):
|
|||
else:
|
||||
raise RGBBandsMissingError()
|
||||
|
||||
num_time_points = sample['image'].shape[0]
|
||||
assert time_step < num_time_points
|
||||
ncols = 1
|
||||
for key in ('mask', 'prediction'):
|
||||
if key in sample:
|
||||
ncols += 1
|
||||
|
||||
image = np.rollaxis(sample['image'][time_step, rgb_indices].numpy(), 0, 3)
|
||||
image = np.clip(image / 2000, 0, 1)
|
||||
fig, axs = plt.subplots(ncols=ncols, squeeze=False)
|
||||
|
||||
image = torch.clamp(sample['image'][time_step, rgb_indices] / 2000, 0, 1)
|
||||
image = rearrange(image, 'c h w -> h w c')
|
||||
axs[0, 0].imshow(image)
|
||||
axs[0, 0].axis('off')
|
||||
if show_titles:
|
||||
axs[0, 0].set_title(f't={time_step}')
|
||||
|
||||
if 'mask' in sample:
|
||||
mask = sample['mask'].numpy()
|
||||
else:
|
||||
mask = np.zeros_like(image)
|
||||
|
||||
num_panels = 2
|
||||
showing_predictions = 'prediction' in sample
|
||||
if showing_predictions:
|
||||
predictions = sample['prediction'].numpy()
|
||||
num_panels += 1
|
||||
|
||||
fig, axs = plt.subplots(ncols=num_panels, figsize=(4 * num_panels, 4))
|
||||
|
||||
axs[0].imshow(image)
|
||||
axs[0].axis('off')
|
||||
if show_titles:
|
||||
axs[0].set_title(f't={time_step}')
|
||||
|
||||
axs[1].imshow(mask, vmin=0, vmax=1, interpolation='none')
|
||||
axs[1].axis('off')
|
||||
if show_titles:
|
||||
axs[1].set_title('Mask')
|
||||
|
||||
if showing_predictions:
|
||||
axs[2].imshow(predictions, vmin=0, vmax=1, interpolation='none')
|
||||
axs[2].axis('off')
|
||||
axs[0, 1].imshow(sample['mask'])
|
||||
axs[0, 1].axis('off')
|
||||
if show_titles:
|
||||
axs[2].set_title('Predictions')
|
||||
axs[0, 1].set_title('Mask')
|
||||
|
||||
if 'prediction' in sample:
|
||||
axs[0, 2].imshow(sample['prediction'])
|
||||
axs[0, 2].axis('off')
|
||||
if show_titles:
|
||||
axs[0, 2].set_title('Prediction')
|
||||
|
||||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
fig.suptitle(suptitle)
|
||||
|
||||
return fig
|
||||
|
|
Загрузка…
Ссылка в новой задаче