Bump matplotlib from 3.7.3 to 3.8.0 in /requirements (#1566)

* Bump matplotlib from 3.7.3 to 3.8.0 in /requirements

Bumps [matplotlib](https://github.com/matplotlib/matplotlib) from 3.7.3 to 3.8.0.
- [Release notes](https://github.com/matplotlib/matplotlib/releases)
- [Commits](https://github.com/matplotlib/matplotlib/compare/v3.7.3...v3.8.0)

---
updated-dependencies:
- dependency-name: matplotlib
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Use canonical import location

* Use canonical import location

* Use canonical import location

* More type fixes

* More type fixes

* More type fixes

* More type fixes

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
dependabot[bot] 2023-09-21 11:35:57 +00:00 коммит произвёл GitHub
Родитель f641d075f9
Коммит 2fbdc85efd
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
61 изменённых файлов: 148 добавлений и 97 удалений

Просмотреть файл

@ -46,7 +46,9 @@ df = pd.read_csv("band_data.csv", skip_blank_lines=True)
df = df.iloc[::-1]
fig, ax = plt.subplots(figsize=(5.5, args.fig_height))
ax1, ax2 = fig.subplots(nrows=1, ncols=2, gridspec_kw={"width_ratios": [3, 1]})
ax1, ax2 = fig.subplots(
nrows=1, ncols=2, gridspec_kw={"width_ratios": [3, 1]}
) # type: ignore[misc]
sensor_names: list[str] = []
sensor_ylocs: list[float] = []
@ -161,4 +163,4 @@ ax2.plot(0, 0, transform=ax2.transAxes, **kwargs)
plt.tight_layout()
plt.subplots_adjust(wspace=0.05)
plt.show()
plt.show() # type: ignore[no-untyped-call]

Просмотреть файл

@ -74,7 +74,7 @@ global_xmax = date.today()
fig, ax = plt.subplots(figsize=(5.5, 3))
cmap = iter(plt.cm.tab10(range(9, 0, -1)))
cmap = iter(plt.cm.tab10(range(9, 0, -1))) # type: ignore[attr-defined]
ymin = args.bar_start
yticks = []
for satellite in range(9, 0, -1):
@ -141,4 +141,4 @@ ax.tick_params(axis="both", which="both", top=False, right=False)
ax.spines[["top", "right"]].set_visible(False)
plt.tight_layout()
plt.show()
plt.show() # type: ignore[no-untyped-call]

Просмотреть файл

@ -74,4 +74,4 @@ ax.legend(fontsize="large")
plt.gca().spines.right.set_visible(False)
plt.gca().spines.top.set_visible(False)
plt.tight_layout()
plt.show()
plt.show() # type: ignore[no-untyped-call]

Просмотреть файл

@ -32,7 +32,7 @@ for i, (label, df) in enumerate(other):
ax.set_xscale("log")
ax.set_xticks([16, 32, 64, 128, 256])
ax.set_xticklabels([16, 32, 64, 128, 256], fontsize=12)
ax.set_xticklabels(["16", "32", "64", "128", "256"], fontsize=12)
ax.set_xlabel("batch size", fontsize=12)
ax.set_ylabel("sampling rate (patches/sec)", fontsize=12)
ax.legend(loc="center right", fontsize="large")
@ -40,4 +40,4 @@ ax.legend(loc="center right", fontsize="large")
plt.gca().spines.right.set_visible(False)
plt.gca().spines.top.set_visible(False)
plt.tight_layout()
plt.show()
plt.show() # type: ignore[no-untyped-call]

Просмотреть файл

@ -49,8 +49,8 @@ for i, (label, df) in enumerate(other):
ax.set_xscale("log")
ax.set_xticks([16, 32, 64, 128, 256])
ax.set_xticklabels([16, 32, 64, 128, 256])
ax.set_xticklabels(["16", "32", "64", "128", "256"])
ax.set_xlabel("batch size")
ax.set_ylabel("% sampling rate (patches/sec)")
ax.legend()
plt.show()
plt.show() # type: ignore[no-untyped-call]

Просмотреть файл

@ -7,7 +7,7 @@ fiona==1.9.4.post1
kornia==0.7.0
lightly==1.4.19
lightning==2.0.9
matplotlib==3.7.3
matplotlib==3.8.0
numpy==1.26.0
pillow==10.0.1
pyproj==3.6.0

Просмотреть файл

@ -8,6 +8,7 @@ import pytest
import torch
from _pytest.fixtures import SubRequest
from lightning.pytorch import Trainer
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor
@ -33,7 +34,7 @@ class CustomGeoDataset(GeoDataset):
image = torch.arange(3 * 2 * 2).view(3, 2, 2)
return {"image": image, "crs": CRS.from_epsg(4326), "bbox": query}
def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
def plot(self, *args: Any, **kwargs: Any) -> Figure:
return plt.figure()
@ -72,7 +73,7 @@ class CustomNonGeoDataset(NonGeoDataset):
def __len__(self) -> int:
return self.length
def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
def plot(self, *args: Any, **kwargs: Any) -> Figure:
return plt.figure()

Просмотреть файл

@ -6,9 +6,9 @@
from typing import Any, Callable, Optional, Union, cast
import kornia.augmentation as K
import matplotlib.pyplot as plt
import torch
from lightning.pytorch import LightningDataModule
from matplotlib.figure import Figure
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, default_collate
@ -141,7 +141,7 @@ class BaseDataModule(LightningDataModule):
return batch
def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
def plot(self, *args: Any, **kwargs: Any) -> Optional[Figure]:
"""Run the plot method of the validation dataset if one exists.
Should only be called during 'fit' or 'validate' stages as ``val_dataset``
@ -154,10 +154,12 @@ class BaseDataModule(LightningDataModule):
Returns:
A matplotlib Figure with the image, ground truth, and predictions.
"""
fig: Optional[Figure] = None
dataset = self.dataset or self.val_dataset
if dataset is not None:
if hasattr(dataset, "plot"):
return dataset.plot(*args, **kwargs)
fig = dataset.plot(*args, **kwargs)
return fig
class GeoDataModule(BaseDataModule):

Просмотреть файл

@ -6,7 +6,7 @@
from typing import Any, Optional, Union
import kornia.augmentation as K
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from ..datasets import NAIP, BoundingBox, Chesapeake13
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
@ -95,7 +95,7 @@ class NAIPChesapeakeDataModule(GeoDataModule):
self.dataset, self.patch_size, self.patch_size, test_roi
)
def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
def plot(self, *args: Any, **kwargs: Any) -> Figure:
"""Run NAIP plot method.
Args:

Просмотреть файл

@ -10,6 +10,7 @@ from typing import Callable, Optional, cast
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
@ -236,7 +237,7 @@ class ADVANCE(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -9,6 +9,7 @@ import os
from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
@ -129,7 +130,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -8,6 +8,7 @@ import os
from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
@ -97,7 +98,7 @@ class AsterGDEM(RasterDataset):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -13,6 +13,7 @@ import numpy as np
import rasterio
import rasterio.features
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor
@ -431,7 +432,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
show_titles: bool = True,
time_step: int = 0,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from matplotlib.figure import Figure
from rasterio.enums import Resampling
from torch import Tensor
@ -533,7 +534,7 @@ class BigEarthNet(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -7,6 +7,7 @@ import os
from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import VectorDataset
@ -127,7 +128,7 @@ class CanadianBuildingFootprints(VectorDataset):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -9,6 +9,7 @@ from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
@ -347,7 +348,7 @@ class CDL(RasterDataset):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -19,6 +19,7 @@ import shapely.geometry
import shapely.ops
import torch
from matplotlib.colors import ListedColormap
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor
@ -173,7 +174,7 @@ class Chesapeake(RasterDataset, abc.ABC):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:
@ -743,7 +744,7 @@ class ChesapeakeCVPR(GeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
@ -355,7 +356,7 @@ class CloudCoverDetection(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -8,6 +8,7 @@ import os
from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
@ -256,7 +257,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -11,6 +11,7 @@ from typing import Callable, Optional, cast
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
@ -196,7 +197,7 @@ class COWC(NonGeoDataset, abc.ABC):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -11,6 +11,7 @@ from typing import Callable, Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
@ -411,7 +412,7 @@ class CV4AKenyaCropType(NonGeoDataset):
show_titles: bool = True,
time_step: int = 0,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -11,6 +11,7 @@ from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
@ -227,7 +228,7 @@ class TropicalCyclone(NonGeoDataset):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -13,6 +13,7 @@ import numpy as np
import rasterio
import torch
from matplotlib import colors
from matplotlib.figure import Figure
from rasterio.enums import Resampling
from torch import Tensor
@ -298,7 +299,7 @@ class DFC2022(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -18,6 +18,7 @@ import shapely.geometry
import shapely.ops
import torch
from matplotlib.colors import ListedColormap
from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import GeoDataset
@ -454,7 +455,7 @@ class EnviroAtlas(GeoDataset):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Note: only plots the "naip" and "lc" layers.

Просмотреть файл

@ -8,6 +8,7 @@ import os
from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
@ -142,7 +143,7 @@ class Esri2020(RasterDataset):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -10,6 +10,7 @@ from typing import Callable, Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
@ -263,7 +264,7 @@ class ETCI2021(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -8,6 +8,7 @@ import os
from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
@ -144,7 +145,7 @@ class EUDEM(RasterDataset):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -10,6 +10,7 @@ from typing import Callable, Optional, cast
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoClassificationDataset
@ -261,7 +262,7 @@ class EuroSAT(NonGeoClassificationDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -12,6 +12,7 @@ import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
@ -395,7 +396,7 @@ class FAIR1M(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -7,6 +7,7 @@ import os
from typing import Callable, Optional, cast
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoClassificationDataset
@ -144,7 +145,7 @@ class FireRisk(NonGeoClassificationDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -12,6 +12,7 @@ import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
@ -263,7 +264,7 @@ class ForestDamage(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -10,6 +10,7 @@ from typing import Callable, Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
@ -241,9 +242,7 @@ class GID15(NonGeoDataset):
md5=self.md5 if self.checksum else None,
)
def plot(
self, sample: dict[str, Tensor], suptitle: Optional[str] = None
) -> plt.Figure:
def plot(self, sample: dict[str, Tensor], suptitle: Optional[str] = None) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -9,6 +9,7 @@ from typing import Any, Callable, Optional, cast
import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
@ -230,7 +231,7 @@ class GlobBiomass(RasterDataset):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from matplotlib.figure import Figure
from rasterio.enums import Resampling
from torch import Tensor
from torchvision.ops import clip_boxes_to_image, remove_small_boxes
@ -496,7 +497,7 @@ class IDTReeS(NonGeoDataset):
show_titles: bool = True,
suptitle: Optional[str] = None,
hsi_indices: tuple[int, int, int] = (0, 1, 2),
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -9,6 +9,7 @@ from collections.abc import Sequence
from typing import Any, Callable, Optional, cast
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor
@ -223,7 +224,7 @@ class L7Irish(RasterDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -9,6 +9,7 @@ from collections.abc import Sequence
from typing import Any, Callable, Optional, cast
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from rasterio.crs import CRS
from torch import Tensor
@ -219,7 +220,7 @@ class L8Biome(RasterDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -13,6 +13,7 @@ import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import ListedColormap
from matplotlib.figure import Figure
from PIL import Image
from rasterio.crs import CRS
from torch import Tensor
@ -155,7 +156,7 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -8,6 +8,7 @@ from collections.abc import Sequence
from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
@ -90,7 +91,7 @@ class Landsat(RasterDataset, abc.ABC):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -10,6 +10,7 @@ from typing import Callable, Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
@ -213,7 +214,7 @@ class LEVIRCDPlus(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -10,6 +10,7 @@ from typing import Callable, Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
@ -264,9 +265,7 @@ class LoveDA(NonGeoDataset):
md5=self.md5 if self.checksum else None,
)
def plot(
self, sample: dict[str, Tensor], suptitle: Optional[str] = None
) -> plt.Figure:
def plot(self, sample: dict[str, Tensor], suptitle: Optional[str] = None) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -9,6 +9,7 @@ from typing import Any, Callable, Optional, cast
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
@ -336,7 +337,7 @@ class MillionAID(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -6,6 +6,7 @@
from typing import Any, Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from .geo import RasterDataset
@ -52,7 +53,7 @@ class NAIP(RasterDataset):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from matplotlib.figure import Figure
from torch import Tensor
from torchvision.utils import draw_bounding_boxes
@ -224,7 +225,7 @@ class NASAMarineDebris(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -9,6 +9,7 @@ from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
@ -245,7 +246,7 @@ class NLCD(RasterDataset):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -16,6 +16,7 @@ import rasterio
import shapely
import shapely.wkt as wkt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS
from rtree.index import Index, Property
@ -434,7 +435,7 @@ class OpenBuildings(VectorDataset):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import ListedColormap
from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
@ -351,7 +352,7 @@ class PASTIS(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -7,6 +7,7 @@ import os
from typing import Callable, Optional, cast
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoClassificationDataset
@ -150,7 +151,7 @@ class PatternNet(NonGeoClassificationDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -11,6 +11,7 @@ import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
@ -223,7 +224,7 @@ class ReforesTree(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -8,6 +8,7 @@ from typing import Callable, Optional, cast
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoClassificationDataset
@ -246,7 +247,7 @@ class RESISC45(NonGeoClassificationDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
@ -232,7 +233,7 @@ class SeasonalContrastS2(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
@ -317,7 +318,7 @@ class SEN12MS(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -8,6 +8,7 @@ from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS
from .geo import RasterDataset
@ -190,7 +191,7 @@ To create a dataset containing both, use:
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:
@ -326,7 +327,7 @@ class Sentinel2(Sentinel):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -9,6 +9,7 @@ from typing import Any, Callable, Optional, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
@ -226,7 +227,7 @@ class SKIPPD(NonGeoDataset):
sample: dict[str, Any],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -10,6 +10,7 @@ from typing import Callable, Optional, cast
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
@ -335,7 +336,7 @@ class So2Sat(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
@ -293,7 +294,7 @@ class SSL4EOL(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:
@ -512,7 +513,7 @@ class SSL4EOS12(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from matplotlib.figure import Figure
from torch import Tensor
from .cdl import CDL
@ -332,7 +333,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -9,6 +9,7 @@ from typing import Any, Callable, Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
@ -238,7 +239,7 @@ class SustainBenchCropYield(NonGeoDataset):
band_idx: int = 0,
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -8,6 +8,7 @@ from typing import Callable, Optional, cast
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms.functional as F
from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoClassificationDataset
@ -222,7 +223,7 @@ class UCMerced(NonGeoClassificationDataset):
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import patches
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
@ -371,7 +372,7 @@ class VHR10(NonGeoDataset):
show_feats: Optional[str] = "both",
box_alpha: float = 0.7,
mask_alpha: float = 0.7,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:
@ -394,13 +395,13 @@ class VHR10(NonGeoDataset):
assert show_feats in {"boxes", "masks", "both"}
if self.split == "negative":
plt.imshow(sample["image"].permute(1, 2, 0))
axs = plt.gca()
axs.axis("off")
fig, axs = plt.subplots(squeeze=False)
axs[0, 0].imshow(sample["image"].permute(1, 2, 0))
axs[0, 0].axis("off")
if suptitle is not None:
plt.suptitle(suptitle)
return plt.gcf()
return fig
if show_feats != "boxes":
try:
@ -437,11 +438,9 @@ class VHR10(NonGeoDataset):
ncols += 1
# Display image
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
if not isinstance(axs, np.ndarray):
axs = [axs]
axs[0].imshow(image)
axs[0].axis("off")
fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(ncols * 10, 10))
axs[0, 0].imshow(image)
axs[0, 0].axis("off")
cm = plt.get_cmap("gist_rainbow")
for i in range(n_gt):
@ -451,7 +450,7 @@ class VHR10(NonGeoDataset):
# Add bounding boxes
x1, y1, x2, y2 = boxes[i]
if show_feats in {"boxes", "both"}:
p = patches.Rectangle(
r = patches.Rectangle(
(x1, y1),
x2 - x1,
y2 - y1,
@ -461,12 +460,12 @@ class VHR10(NonGeoDataset):
edgecolor=color,
facecolor="none",
)
axs[0].add_patch(p)
axs[0, 0].add_patch(r)
# Add labels
label = self.categories[class_num]
caption = label
axs[0].text(
axs[0, 0].text(
x1, y1 - 8, caption, color="white", size=11, backgroundcolor="none"
)
@ -479,14 +478,14 @@ class VHR10(NonGeoDataset):
p = patches.Polygon(
verts, facecolor=color, alpha=mask_alpha, edgecolor="white"
)
axs[0].add_patch(p)
axs[0, 0].add_patch(p)
if show_titles:
axs[0].set_title("Ground Truth")
axs[0, 0].set_title("Ground Truth")
if show_predictions:
axs[1].imshow(image)
axs[1].axis("off")
axs[0, 1].imshow(image)
axs[0, 1].axis("off")
for i in range(n_pred):
score = prediction_scores[i]
if score < 0.5:
@ -498,7 +497,7 @@ class VHR10(NonGeoDataset):
if show_pred_boxes:
# Add bounding boxes
x1, y1, x2, y2 = prediction_boxes[i]
p = patches.Rectangle(
r = patches.Rectangle(
(x1, y1),
x2 - x1,
y2 - y1,
@ -508,12 +507,12 @@ class VHR10(NonGeoDataset):
edgecolor=color,
facecolor="none",
)
axs[1].add_patch(p)
axs[0, 1].add_patch(r)
# Add labels
label = self.categories[class_num]
caption = f"{label} {score:.3f}"
axs[1].text(
axs[0, 1].text(
x1,
y1 - 8,
caption,
@ -531,10 +530,10 @@ class VHR10(NonGeoDataset):
p = patches.Polygon(
verts, facecolor=color, alpha=mask_alpha, edgecolor="white"
)
axs[1].add_patch(p)
axs[0, 1].add_patch(p)
if show_titles:
axs[1].set_title("Prediction")
axs[0, 1].set_title("Prediction")
plt.tight_layout()

Просмотреть файл

@ -10,6 +10,7 @@ from typing import Callable, Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
@ -225,7 +226,7 @@ class XView2(NonGeoDataset):
show_titles: bool = True,
suptitle: Optional[str] = None,
alpha: float = 0.5,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args:

Просмотреть файл

@ -9,6 +9,7 @@ from typing import Callable, Optional
import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from torch import Tensor
from .geo import NonGeoDataset
@ -267,7 +268,7 @@ class ZueriCrop(NonGeoDataset):
time_step: int = 0,
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
) -> Figure:
"""Plot a sample from the dataset.
Args: