Add plot method to PatternNet dataset (#314)

* Adding plot method to PatternNet dataset

* classes and doc fix

* remove classes and adjust directory

* fix test

* handling directory and documentation

* md5

* md5

* semicolon
This commit is contained in:
Nils Lehmann 2021-12-31 18:00:15 +01:00 коммит произвёл GitHub
Родитель c90419b38f
Коммит 2375a512be
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 53 добавлений и 2 удалений

Двоичные данные
tests/data/patternnet/PatternNet.zip

Двоичный файл не отображается.

Просмотреть файл

@ -6,6 +6,7 @@ import shutil
from pathlib import Path
from typing import Generator
import matplotlib.pyplot as plt
import pytest
import torch
import torch.nn as nn
@ -61,3 +62,12 @@ class TestPatternNet:
"to automaticaly download the dataset."
with pytest.raises(RuntimeError, match=err):
PatternNet(str(tmp_path))
def test_plot(self, dataset: PatternNet) -> None:
dataset.plot(dataset[0], suptitle="Test")
plt.close()
sample = dataset[0]
sample["prediction"] = sample["label"].clone()
dataset.plot(sample, suptitle="Prediction")
plt.close()

Просмотреть файл

@ -4,8 +4,9 @@
"""PatternNet dataset."""
import os
from typing import Callable, Dict, Optional
from typing import Callable, Dict, Optional, cast
import matplotlib.pyplot as plt
from torch import Tensor
from .geo import VisionClassificationDataset
@ -77,7 +78,7 @@ class PatternNet(VisionClassificationDataset):
url = "https://drive.google.com/file/d/127lxXYqzO6Bd0yZhvEbgIfz95HaEnr9K"
md5 = "96d54b3224c5350a98d55d5a7e6984ad"
filename = "PatternNet.zip"
directory = "images"
directory = os.path.join("PatternNet", "images")
def __init__(
self,
@ -143,3 +144,43 @@ class PatternNet(VisionClassificationDataset):
"""Extract the dataset."""
filepath = os.path.join(self.root, self.filename)
extract_archive(filepath)
def plot(
self,
sample: Dict[str, Tensor],
show_titles: bool = True,
suptitle: Optional[str] = None,
) -> plt.Figure:
"""Plot a sample from the dataset.
Args:
sample: a sample returned by :meth:`VisionClassificationDataset.__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional suptitle to use for figure
Returns:
a matplotlib Figure with the rendered sample
.. versionadded:: 0.2
"""
image, label = sample["image"], cast(int, sample["label"].item())
showing_predictions = "prediction" in sample
if showing_predictions:
prediction = cast(int, sample["prediction"].item())
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(image.permute(1, 2, 0))
ax.axis("off")
if show_titles:
title = f"Label: {self.classes[label]}"
if showing_predictions:
title += f"\nPrediction: {self.classes[prediction]}"
ax.set_title(title)
if suptitle is not None:
plt.suptitle(suptitle)
return fig