зеркало из https://github.com/microsoft/torchgeo.git
Bump matplotlib from 3.7.3 to 3.8.0 in /requirements (#1566)
* Bump matplotlib from 3.7.3 to 3.8.0 in /requirements Bumps [matplotlib](https://github.com/matplotlib/matplotlib) from 3.7.3 to 3.8.0. - [Release notes](https://github.com/matplotlib/matplotlib/releases) - [Commits](https://github.com/matplotlib/matplotlib/compare/v3.7.3...v3.8.0) --- updated-dependencies: - dependency-name: matplotlib dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * Use canonical import location * Use canonical import location * Use canonical import location * More type fixes * More type fixes * More type fixes * More type fixes --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
This commit is contained in:
Родитель
f641d075f9
Коммит
2fbdc85efd
|
@ -46,7 +46,9 @@ df = pd.read_csv("band_data.csv", skip_blank_lines=True)
|
|||
df = df.iloc[::-1]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(5.5, args.fig_height))
|
||||
ax1, ax2 = fig.subplots(nrows=1, ncols=2, gridspec_kw={"width_ratios": [3, 1]})
|
||||
ax1, ax2 = fig.subplots(
|
||||
nrows=1, ncols=2, gridspec_kw={"width_ratios": [3, 1]}
|
||||
) # type: ignore[misc]
|
||||
|
||||
sensor_names: list[str] = []
|
||||
sensor_ylocs: list[float] = []
|
||||
|
@ -161,4 +163,4 @@ ax2.plot(0, 0, transform=ax2.transAxes, **kwargs)
|
|||
|
||||
plt.tight_layout()
|
||||
plt.subplots_adjust(wspace=0.05)
|
||||
plt.show()
|
||||
plt.show() # type: ignore[no-untyped-call]
|
||||
|
|
|
@ -74,7 +74,7 @@ global_xmax = date.today()
|
|||
|
||||
fig, ax = plt.subplots(figsize=(5.5, 3))
|
||||
|
||||
cmap = iter(plt.cm.tab10(range(9, 0, -1)))
|
||||
cmap = iter(plt.cm.tab10(range(9, 0, -1))) # type: ignore[attr-defined]
|
||||
ymin = args.bar_start
|
||||
yticks = []
|
||||
for satellite in range(9, 0, -1):
|
||||
|
@ -141,4 +141,4 @@ ax.tick_params(axis="both", which="both", top=False, right=False)
|
|||
ax.spines[["top", "right"]].set_visible(False)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
plt.show() # type: ignore[no-untyped-call]
|
||||
|
|
|
@ -74,4 +74,4 @@ ax.legend(fontsize="large")
|
|||
plt.gca().spines.right.set_visible(False)
|
||||
plt.gca().spines.top.set_visible(False)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
plt.show() # type: ignore[no-untyped-call]
|
||||
|
|
|
@ -32,7 +32,7 @@ for i, (label, df) in enumerate(other):
|
|||
|
||||
ax.set_xscale("log")
|
||||
ax.set_xticks([16, 32, 64, 128, 256])
|
||||
ax.set_xticklabels([16, 32, 64, 128, 256], fontsize=12)
|
||||
ax.set_xticklabels(["16", "32", "64", "128", "256"], fontsize=12)
|
||||
ax.set_xlabel("batch size", fontsize=12)
|
||||
ax.set_ylabel("sampling rate (patches/sec)", fontsize=12)
|
||||
ax.legend(loc="center right", fontsize="large")
|
||||
|
@ -40,4 +40,4 @@ ax.legend(loc="center right", fontsize="large")
|
|||
plt.gca().spines.right.set_visible(False)
|
||||
plt.gca().spines.top.set_visible(False)
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
plt.show() # type: ignore[no-untyped-call]
|
||||
|
|
|
@ -49,8 +49,8 @@ for i, (label, df) in enumerate(other):
|
|||
|
||||
ax.set_xscale("log")
|
||||
ax.set_xticks([16, 32, 64, 128, 256])
|
||||
ax.set_xticklabels([16, 32, 64, 128, 256])
|
||||
ax.set_xticklabels(["16", "32", "64", "128", "256"])
|
||||
ax.set_xlabel("batch size")
|
||||
ax.set_ylabel("% sampling rate (patches/sec)")
|
||||
ax.legend()
|
||||
plt.show()
|
||||
plt.show() # type: ignore[no-untyped-call]
|
||||
|
|
|
@ -7,7 +7,7 @@ fiona==1.9.4.post1
|
|||
kornia==0.7.0
|
||||
lightly==1.4.19
|
||||
lightning==2.0.9
|
||||
matplotlib==3.7.3
|
||||
matplotlib==3.8.0
|
||||
numpy==1.26.0
|
||||
pillow==10.0.1
|
||||
pyproj==3.6.0
|
||||
|
|
|
@ -8,6 +8,7 @@ import pytest
|
|||
import torch
|
||||
from _pytest.fixtures import SubRequest
|
||||
from lightning.pytorch import Trainer
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -33,7 +34,7 @@ class CustomGeoDataset(GeoDataset):
|
|||
image = torch.arange(3 * 2 * 2).view(3, 2, 2)
|
||||
return {"image": image, "crs": CRS.from_epsg(4326), "bbox": query}
|
||||
|
||||
def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
|
||||
def plot(self, *args: Any, **kwargs: Any) -> Figure:
|
||||
return plt.figure()
|
||||
|
||||
|
||||
|
@ -72,7 +73,7 @@ class CustomNonGeoDataset(NonGeoDataset):
|
|||
def __len__(self) -> int:
|
||||
return self.length
|
||||
|
||||
def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
|
||||
def plot(self, *args: Any, **kwargs: Any) -> Figure:
|
||||
return plt.figure()
|
||||
|
||||
|
||||
|
|
|
@ -6,9 +6,9 @@
|
|||
from typing import Any, Callable, Optional, Union, cast
|
||||
|
||||
import kornia.augmentation as K
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from lightning.pytorch import LightningDataModule
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader, Dataset, default_collate
|
||||
|
||||
|
@ -141,7 +141,7 @@ class BaseDataModule(LightningDataModule):
|
|||
|
||||
return batch
|
||||
|
||||
def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
|
||||
def plot(self, *args: Any, **kwargs: Any) -> Optional[Figure]:
|
||||
"""Run the plot method of the validation dataset if one exists.
|
||||
|
||||
Should only be called during 'fit' or 'validate' stages as ``val_dataset``
|
||||
|
@ -154,10 +154,12 @@ class BaseDataModule(LightningDataModule):
|
|||
Returns:
|
||||
A matplotlib Figure with the image, ground truth, and predictions.
|
||||
"""
|
||||
fig: Optional[Figure] = None
|
||||
dataset = self.dataset or self.val_dataset
|
||||
if dataset is not None:
|
||||
if hasattr(dataset, "plot"):
|
||||
return dataset.plot(*args, **kwargs)
|
||||
fig = dataset.plot(*args, **kwargs)
|
||||
return fig
|
||||
|
||||
|
||||
class GeoDataModule(BaseDataModule):
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
from typing import Any, Optional, Union
|
||||
|
||||
import kornia.augmentation as K
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from ..datasets import NAIP, BoundingBox, Chesapeake13
|
||||
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
|
||||
|
@ -95,7 +95,7 @@ class NAIPChesapeakeDataModule(GeoDataModule):
|
|||
self.dataset, self.patch_size, self.patch_size, test_roi
|
||||
)
|
||||
|
||||
def plot(self, *args: Any, **kwargs: Any) -> plt.Figure:
|
||||
def plot(self, *args: Any, **kwargs: Any) -> Figure:
|
||||
"""Run NAIP plot method.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Callable, Optional, cast
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -236,7 +237,7 @@ class ADVANCE(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -9,6 +9,7 @@ import os
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
|
@ -129,7 +130,7 @@ class AbovegroundLiveWoodyBiomassDensity(RasterDataset):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -8,6 +8,7 @@ import os
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
|
@ -97,7 +98,7 @@ class AsterGDEM(RasterDataset):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -13,6 +13,7 @@ import numpy as np
|
|||
import rasterio
|
||||
import rasterio.features
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -431,7 +432,7 @@ class BeninSmallHolderCashews(NonGeoDataset):
|
|||
show_titles: bool = True,
|
||||
time_step: int = 0,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
|
|||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.enums import Resampling
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -533,7 +534,7 @@ class BigEarthNet(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -7,6 +7,7 @@ import os
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import VectorDataset
|
||||
|
@ -127,7 +128,7 @@ class CanadianBuildingFootprints(VectorDataset):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any, Callable, Optional
|
|||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
|
@ -347,7 +348,7 @@ class CDL(RasterDataset):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -19,6 +19,7 @@ import shapely.geometry
|
|||
import shapely.ops
|
||||
import torch
|
||||
from matplotlib.colors import ListedColormap
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -173,7 +174,7 @@ class Chesapeake(RasterDataset, abc.ABC):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
@ -743,7 +744,7 @@ class ChesapeakeCVPR(GeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
|
|||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
|
@ -355,7 +356,7 @@ class CloudCoverDetection(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -8,6 +8,7 @@ import os
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
|
@ -256,7 +257,7 @@ class CMSGlobalMangroveCanopy(RasterDataset):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -11,6 +11,7 @@ from typing import Callable, Optional, cast
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -196,7 +197,7 @@ class COWC(NonGeoDataset, abc.ABC):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -11,6 +11,7 @@ from typing import Callable, Optional
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -411,7 +412,7 @@ class CV4AKenyaCropType(NonGeoDataset):
|
|||
show_titles: bool = True,
|
||||
time_step: int = 0,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -11,6 +11,7 @@ from typing import Any, Callable, Optional
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -227,7 +228,7 @@ class TropicalCyclone(NonGeoDataset):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -13,6 +13,7 @@ import numpy as np
|
|||
import rasterio
|
||||
import torch
|
||||
from matplotlib import colors
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.enums import Resampling
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -298,7 +299,7 @@ class DFC2022(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -18,6 +18,7 @@ import shapely.geometry
|
|||
import shapely.ops
|
||||
import torch
|
||||
from matplotlib.colors import ListedColormap
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import GeoDataset
|
||||
|
@ -454,7 +455,7 @@ class EnviroAtlas(GeoDataset):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Note: only plots the "naip" and "lc" layers.
|
||||
|
|
|
@ -8,6 +8,7 @@ import os
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
|
@ -142,7 +143,7 @@ class Esri2020(RasterDataset):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Callable, Optional
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -263,7 +264,7 @@ class ETCI2021(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -8,6 +8,7 @@ import os
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
|
@ -144,7 +145,7 @@ class EUDEM(RasterDataset):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Callable, Optional, cast
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoClassificationDataset
|
||||
|
@ -261,7 +262,7 @@ class EuroSAT(NonGeoClassificationDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -12,6 +12,7 @@ import matplotlib.patches as patches
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -395,7 +396,7 @@ class FAIR1M(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -7,6 +7,7 @@ import os
|
|||
from typing import Callable, Optional, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoClassificationDataset
|
||||
|
@ -144,7 +145,7 @@ class FireRisk(NonGeoClassificationDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -12,6 +12,7 @@ import matplotlib.patches as patches
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -263,7 +264,7 @@ class ForestDamage(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Callable, Optional
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -241,9 +242,7 @@ class GID15(NonGeoDataset):
|
|||
md5=self.md5 if self.checksum else None,
|
||||
)
|
||||
|
||||
def plot(
|
||||
self, sample: dict[str, Tensor], suptitle: Optional[str] = None
|
||||
) -> plt.Figure:
|
||||
def plot(self, sample: dict[str, Tensor], suptitle: Optional[str] = None) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any, Callable, Optional, cast
|
|||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
|
@ -230,7 +231,7 @@ class GlobBiomass(RasterDataset):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
|
|||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.enums import Resampling
|
||||
from torch import Tensor
|
||||
from torchvision.ops import clip_boxes_to_image, remove_small_boxes
|
||||
|
@ -496,7 +497,7 @@ class IDTReeS(NonGeoDataset):
|
|||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
hsi_indices: tuple[int, int, int] = (0, 1, 2),
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -9,6 +9,7 @@ from collections.abc import Sequence
|
|||
from typing import Any, Callable, Optional, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -223,7 +224,7 @@ class L7Irish(RasterDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -9,6 +9,7 @@ from collections.abc import Sequence
|
|||
from typing import Any, Callable, Optional, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -219,7 +220,7 @@ class L8Biome(RasterDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -13,6 +13,7 @@ import matplotlib.pyplot as plt
|
|||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.colors import ListedColormap
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from rasterio.crs import CRS
|
||||
from torch import Tensor
|
||||
|
@ -155,7 +156,7 @@ class LandCoverAIBase(Dataset[dict[str, Any]], abc.ABC):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -8,6 +8,7 @@ from collections.abc import Sequence
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
|
@ -90,7 +91,7 @@ class Landsat(RasterDataset, abc.ABC):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Callable, Optional
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -213,7 +214,7 @@ class LEVIRCDPlus(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Callable, Optional
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -264,9 +265,7 @@ class LoveDA(NonGeoDataset):
|
|||
md5=self.md5 if self.checksum else None,
|
||||
)
|
||||
|
||||
def plot(
|
||||
self, sample: dict[str, Tensor], suptitle: Optional[str] = None
|
||||
) -> plt.Figure:
|
||||
def plot(self, sample: dict[str, Tensor], suptitle: Optional[str] = None) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any, Callable, Optional, cast
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -336,7 +337,7 @@ class MillionAID(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
|
||||
from .geo import RasterDataset
|
||||
|
||||
|
@ -52,7 +53,7 @@ class NAIP(RasterDataset):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
|
|||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
from torchvision.utils import draw_bounding_boxes
|
||||
|
||||
|
@ -224,7 +225,7 @@ class NASAMarineDebris(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any, Callable, Optional
|
|||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
|
@ -245,7 +246,7 @@ class NLCD(RasterDataset):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -16,6 +16,7 @@ import rasterio
|
|||
import shapely
|
||||
import shapely.wkt as wkt
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
from rtree.index import Index, Property
|
||||
|
||||
|
@ -434,7 +435,7 @@ class OpenBuildings(VectorDataset):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
|
|||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.colors import ListedColormap
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
|
@ -351,7 +352,7 @@ class PASTIS(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -7,6 +7,7 @@ import os
|
|||
from typing import Callable, Optional, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoClassificationDataset
|
||||
|
@ -150,7 +151,7 @@ class PatternNet(NonGeoClassificationDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -11,6 +11,7 @@ import matplotlib.patches as patches
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -223,7 +224,7 @@ class ReforesTree(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Callable, Optional, cast
|
|||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoClassificationDataset
|
||||
|
@ -246,7 +247,7 @@ class RESISC45(NonGeoClassificationDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
|
|||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -232,7 +233,7 @@ class SeasonalContrastS2(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
|
|||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
|
@ -317,7 +318,7 @@ class SEN12MS(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Callable, Optional
|
|||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from rasterio.crs import CRS
|
||||
|
||||
from .geo import RasterDataset
|
||||
|
@ -190,7 +191,7 @@ To create a dataset containing both, use:
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
@ -326,7 +327,7 @@ class Sentinel2(Sentinel):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any, Callable, Optional, Union
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
|
@ -226,7 +227,7 @@ class SKIPPD(NonGeoDataset):
|
|||
sample: dict[str, Any],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Callable, Optional, cast
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
|
@ -335,7 +336,7 @@ class So2Sat(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -12,6 +12,7 @@ import matplotlib.pyplot as plt
|
|||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
|
@ -293,7 +294,7 @@ class SSL4EOL(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
@ -512,7 +513,7 @@ class SSL4EOS12(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -11,6 +11,7 @@ import matplotlib.pyplot as plt
|
|||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .cdl import CDL
|
||||
|
@ -332,7 +333,7 @@ class SSL4EOLBenchmark(NonGeoDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any, Callable, Optional
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
|
@ -238,7 +239,7 @@ class SustainBenchCropYield(NonGeoDataset):
|
|||
band_idx: int = 0,
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Callable, Optional, cast
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torchvision.transforms.functional as F
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoClassificationDataset
|
||||
|
@ -222,7 +223,7 @@ class UCMerced(NonGeoClassificationDataset):
|
|||
sample: dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
|
|||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import patches
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -371,7 +372,7 @@ class VHR10(NonGeoDataset):
|
|||
show_feats: Optional[str] = "both",
|
||||
box_alpha: float = 0.7,
|
||||
mask_alpha: float = 0.7,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
@ -394,13 +395,13 @@ class VHR10(NonGeoDataset):
|
|||
assert show_feats in {"boxes", "masks", "both"}
|
||||
|
||||
if self.split == "negative":
|
||||
plt.imshow(sample["image"].permute(1, 2, 0))
|
||||
axs = plt.gca()
|
||||
axs.axis("off")
|
||||
fig, axs = plt.subplots(squeeze=False)
|
||||
axs[0, 0].imshow(sample["image"].permute(1, 2, 0))
|
||||
axs[0, 0].axis("off")
|
||||
|
||||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
return plt.gcf()
|
||||
return fig
|
||||
|
||||
if show_feats != "boxes":
|
||||
try:
|
||||
|
@ -437,11 +438,9 @@ class VHR10(NonGeoDataset):
|
|||
ncols += 1
|
||||
|
||||
# Display image
|
||||
fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10))
|
||||
if not isinstance(axs, np.ndarray):
|
||||
axs = [axs]
|
||||
axs[0].imshow(image)
|
||||
axs[0].axis("off")
|
||||
fig, axs = plt.subplots(ncols=ncols, squeeze=False, figsize=(ncols * 10, 10))
|
||||
axs[0, 0].imshow(image)
|
||||
axs[0, 0].axis("off")
|
||||
|
||||
cm = plt.get_cmap("gist_rainbow")
|
||||
for i in range(n_gt):
|
||||
|
@ -451,7 +450,7 @@ class VHR10(NonGeoDataset):
|
|||
# Add bounding boxes
|
||||
x1, y1, x2, y2 = boxes[i]
|
||||
if show_feats in {"boxes", "both"}:
|
||||
p = patches.Rectangle(
|
||||
r = patches.Rectangle(
|
||||
(x1, y1),
|
||||
x2 - x1,
|
||||
y2 - y1,
|
||||
|
@ -461,12 +460,12 @@ class VHR10(NonGeoDataset):
|
|||
edgecolor=color,
|
||||
facecolor="none",
|
||||
)
|
||||
axs[0].add_patch(p)
|
||||
axs[0, 0].add_patch(r)
|
||||
|
||||
# Add labels
|
||||
label = self.categories[class_num]
|
||||
caption = label
|
||||
axs[0].text(
|
||||
axs[0, 0].text(
|
||||
x1, y1 - 8, caption, color="white", size=11, backgroundcolor="none"
|
||||
)
|
||||
|
||||
|
@ -479,14 +478,14 @@ class VHR10(NonGeoDataset):
|
|||
p = patches.Polygon(
|
||||
verts, facecolor=color, alpha=mask_alpha, edgecolor="white"
|
||||
)
|
||||
axs[0].add_patch(p)
|
||||
axs[0, 0].add_patch(p)
|
||||
|
||||
if show_titles:
|
||||
axs[0].set_title("Ground Truth")
|
||||
axs[0, 0].set_title("Ground Truth")
|
||||
|
||||
if show_predictions:
|
||||
axs[1].imshow(image)
|
||||
axs[1].axis("off")
|
||||
axs[0, 1].imshow(image)
|
||||
axs[0, 1].axis("off")
|
||||
for i in range(n_pred):
|
||||
score = prediction_scores[i]
|
||||
if score < 0.5:
|
||||
|
@ -498,7 +497,7 @@ class VHR10(NonGeoDataset):
|
|||
if show_pred_boxes:
|
||||
# Add bounding boxes
|
||||
x1, y1, x2, y2 = prediction_boxes[i]
|
||||
p = patches.Rectangle(
|
||||
r = patches.Rectangle(
|
||||
(x1, y1),
|
||||
x2 - x1,
|
||||
y2 - y1,
|
||||
|
@ -508,12 +507,12 @@ class VHR10(NonGeoDataset):
|
|||
edgecolor=color,
|
||||
facecolor="none",
|
||||
)
|
||||
axs[1].add_patch(p)
|
||||
axs[0, 1].add_patch(r)
|
||||
|
||||
# Add labels
|
||||
label = self.categories[class_num]
|
||||
caption = f"{label} {score:.3f}"
|
||||
axs[1].text(
|
||||
axs[0, 1].text(
|
||||
x1,
|
||||
y1 - 8,
|
||||
caption,
|
||||
|
@ -531,10 +530,10 @@ class VHR10(NonGeoDataset):
|
|||
p = patches.Polygon(
|
||||
verts, facecolor=color, alpha=mask_alpha, edgecolor="white"
|
||||
)
|
||||
axs[1].add_patch(p)
|
||||
axs[0, 1].add_patch(p)
|
||||
|
||||
if show_titles:
|
||||
axs[1].set_title("Prediction")
|
||||
axs[0, 1].set_title("Prediction")
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Callable, Optional
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -225,7 +226,7 @@ class XView2(NonGeoDataset):
|
|||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
alpha: float = 0.5,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Callable, Optional
|
|||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import NonGeoDataset
|
||||
|
@ -267,7 +268,7 @@ class ZueriCrop(NonGeoDataset):
|
|||
time_step: int = 0,
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
) -> Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
|
|
Загрузка…
Ссылка в новой задаче