зеркало из https://github.com/microsoft/torchgeo.git
Add plot method to PatternNet dataset (#314)
* Adding plot method to PatternNet dataset * classes and doc fix * remove classes and adjust directory * fix test * handling directory and documentation * md5 * md5 * semicolon
This commit is contained in:
Родитель
c90419b38f
Коммит
2375a512be
Двоичные данные
tests/data/patternnet/PatternNet.zip
Двоичные данные
tests/data/patternnet/PatternNet.zip
Двоичный файл не отображается.
|
@ -6,6 +6,7 @@ import shutil
|
|||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -61,3 +62,12 @@ class TestPatternNet:
|
|||
"to automaticaly download the dataset."
|
||||
with pytest.raises(RuntimeError, match=err):
|
||||
PatternNet(str(tmp_path))
|
||||
|
||||
def test_plot(self, dataset: PatternNet) -> None:
|
||||
dataset.plot(dataset[0], suptitle="Test")
|
||||
plt.close()
|
||||
|
||||
sample = dataset[0]
|
||||
sample["prediction"] = sample["label"].clone()
|
||||
dataset.plot(sample, suptitle="Prediction")
|
||||
plt.close()
|
||||
|
|
|
@ -4,8 +4,9 @@
|
|||
"""PatternNet dataset."""
|
||||
|
||||
import os
|
||||
from typing import Callable, Dict, Optional
|
||||
from typing import Callable, Dict, Optional, cast
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from torch import Tensor
|
||||
|
||||
from .geo import VisionClassificationDataset
|
||||
|
@ -77,7 +78,7 @@ class PatternNet(VisionClassificationDataset):
|
|||
url = "https://drive.google.com/file/d/127lxXYqzO6Bd0yZhvEbgIfz95HaEnr9K"
|
||||
md5 = "96d54b3224c5350a98d55d5a7e6984ad"
|
||||
filename = "PatternNet.zip"
|
||||
directory = "images"
|
||||
directory = os.path.join("PatternNet", "images")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -143,3 +144,43 @@ class PatternNet(VisionClassificationDataset):
|
|||
"""Extract the dataset."""
|
||||
filepath = os.path.join(self.root, self.filename)
|
||||
extract_archive(filepath)
|
||||
|
||||
def plot(
|
||||
self,
|
||||
sample: Dict[str, Tensor],
|
||||
show_titles: bool = True,
|
||||
suptitle: Optional[str] = None,
|
||||
) -> plt.Figure:
|
||||
"""Plot a sample from the dataset.
|
||||
|
||||
Args:
|
||||
sample: a sample returned by :meth:`VisionClassificationDataset.__getitem__`
|
||||
show_titles: flag indicating whether to show titles above each panel
|
||||
suptitle: optional suptitle to use for figure
|
||||
|
||||
Returns:
|
||||
a matplotlib Figure with the rendered sample
|
||||
|
||||
.. versionadded:: 0.2
|
||||
"""
|
||||
image, label = sample["image"], cast(int, sample["label"].item())
|
||||
|
||||
showing_predictions = "prediction" in sample
|
||||
if showing_predictions:
|
||||
prediction = cast(int, sample["prediction"].item())
|
||||
|
||||
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
|
||||
|
||||
ax.imshow(image.permute(1, 2, 0))
|
||||
ax.axis("off")
|
||||
|
||||
if show_titles:
|
||||
title = f"Label: {self.classes[label]}"
|
||||
if showing_predictions:
|
||||
title += f"\nPrediction: {self.classes[prediction]}"
|
||||
ax.set_title(title)
|
||||
|
||||
if suptitle is not None:
|
||||
plt.suptitle(suptitle)
|
||||
|
||||
return fig
|
||||
|
|
Загрузка…
Ссылка в новой задаче