Update transforms for new Dict[str, Any] style

This commit is contained in:
Adam J. Stewart 2021-06-11 20:50:56 +00:00
Родитель 39ea1be875
Коммит eacb5685f3
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: C66C0675661156FC
2 изменённых файлов: 56 добавлений и 37 удалений

Просмотреть файл

@ -1,8 +1,11 @@
[mypy]
python_version = 3.9
ignore_missing_imports = True
show_error_codes = True
# Strict
warn_unused_configs = True
disallow_any_generics = True
disallow_subclassing_any = True
disallow_untyped_calls = True
disallow_untyped_defs = True
disallow_incomplete_defs = True
@ -12,4 +15,5 @@ no_implicit_optional = True
warn_redundant_casts = True
warn_unused_ignores = True
warn_return_any = True
#no_implicit_reexport = True
no_implicit_reexport = True
strict_equality = True

Просмотреть файл

@ -1,60 +1,75 @@
from typing import Dict, Optional, Tuple
from typing import Any, Dict
import torch
from torch import Tensor
import torchvision.transforms as T
import torch.nn as nn
import torchvision.transforms.functional as F
class RandomHorizontalFlip(T.RandomHorizontalFlip):
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
# TODO: figure out why mypy is angry:
# https://discuss.pytorch.org/t/how-to-correctly-annotate-subclasses-of-nn-module/74317/2
class RandomHorizontalFlip(nn.Module): # type: ignore[misc,name-defined]
"""Horizontally flip the given sample randomly with a given probability."""
def __init__(self, p: float = 0.5) -> None:
"""Initialize a new transform instance.
Parameters:
p: probability of the sample being flipped
"""
super().__init__()
self.p = p
def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Randomly flip the image and target tensors.
Parameters:
image: image to be flipped
target: optional bounding boxes and masks to flip
sample: a single data sample
Returns:
randomly flipped image and target
a possibly flipped sample
"""
if torch.rand(1) < self.p:
image = F.hflip(image)
if "image" in sample:
sample["image"] = F.hflip(sample["image"])
width, height = F._get_image_size(sample["image"])
if target is not None:
width, height = F._get_image_size(image)
if "boxes" in sample:
sample["boxes"][:, [0, 2]] = width - sample["boxes"][:, [2, 0]]
if "masks" in sample:
sample["masks"] = sample["masks"].flip(-1)
if "boxes" in target:
target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
return image, target
return sample
class RandomVerticalFlip(T.RandomVerticalFlip):
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
class RandomVerticalFlip(nn.Module): # type: ignore[misc,name-defined]
"""Vertically flip the given sample randomly with a given probability."""
def __init__(self, p: float = 0.5) -> None:
"""Initialize a new transform instance.
Parameters:
p: probability of the sample being flipped
"""
super().__init__()
self.p = p
def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Randomly flip the image and target tensors.
Parameters:
image: image to be flipped
target: optional bounding boxes and masks to flip
sample: a single data sample
Returns:
randomly flipped image and target
a possibly flipped sample
"""
if torch.rand(1) < self.p:
image = F.vflip(image)
if "image" in sample:
sample["image"] = F.vflip(sample["image"])
width, height = F._get_image_size(sample["image"])
if target is not None:
width, height = F._get_image_size(image)
if "boxes" in sample:
sample["boxes"][:, [1, 3]] = height - sample["boxes"][:, [3, 1]]
if "masks" in sample:
sample["masks"] = sample["masks"].flip(-2)
if "boxes" in target:
target["boxes"][:, [1, 3]] = height - target["boxes"][:, [3, 1]]
if "masks" in target:
target["masks"] = target["masks"].flip(-2)
return image, target
return sample