зеркало из https://github.com/microsoft/torchgeo.git
Update transforms for new Dict[str, Any] style
This commit is contained in:
Родитель
39ea1be875
Коммит
eacb5685f3
8
mypy.ini
8
mypy.ini
|
@ -1,8 +1,11 @@
|
|||
[mypy]
|
||||
python_version = 3.9
|
||||
ignore_missing_imports = True
|
||||
show_error_codes = True
|
||||
|
||||
# Strict
|
||||
warn_unused_configs = True
|
||||
disallow_any_generics = True
|
||||
disallow_subclassing_any = True
|
||||
disallow_untyped_calls = True
|
||||
disallow_untyped_defs = True
|
||||
disallow_incomplete_defs = True
|
||||
|
@ -12,4 +15,5 @@ no_implicit_optional = True
|
|||
warn_redundant_casts = True
|
||||
warn_unused_ignores = True
|
||||
warn_return_any = True
|
||||
#no_implicit_reexport = True
|
||||
no_implicit_reexport = True
|
||||
strict_equality = True
|
||||
|
|
|
@ -1,60 +1,75 @@
|
|||
from typing import Dict, Optional, Tuple
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torchvision.transforms as T
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms.functional as F
|
||||
|
||||
|
||||
class RandomHorizontalFlip(T.RandomHorizontalFlip):
|
||||
def forward(
|
||||
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
||||
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
||||
# TODO: figure out why mypy is angry:
|
||||
# https://discuss.pytorch.org/t/how-to-correctly-annotate-subclasses-of-nn-module/74317/2
|
||||
class RandomHorizontalFlip(nn.Module): # type: ignore[misc,name-defined]
|
||||
"""Horizontally flip the given sample randomly with a given probability."""
|
||||
|
||||
def __init__(self, p: float = 0.5) -> None:
|
||||
"""Initialize a new transform instance.
|
||||
|
||||
Parameters:
|
||||
p: probability of the sample being flipped
|
||||
"""
|
||||
super().__init__()
|
||||
self.p = p
|
||||
|
||||
def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Randomly flip the image and target tensors.
|
||||
|
||||
Parameters:
|
||||
image: image to be flipped
|
||||
target: optional bounding boxes and masks to flip
|
||||
sample: a single data sample
|
||||
|
||||
Returns:
|
||||
randomly flipped image and target
|
||||
a possibly flipped sample
|
||||
"""
|
||||
if torch.rand(1) < self.p:
|
||||
image = F.hflip(image)
|
||||
if "image" in sample:
|
||||
sample["image"] = F.hflip(sample["image"])
|
||||
width, height = F._get_image_size(sample["image"])
|
||||
|
||||
if target is not None:
|
||||
width, height = F._get_image_size(image)
|
||||
if "boxes" in sample:
|
||||
sample["boxes"][:, [0, 2]] = width - sample["boxes"][:, [2, 0]]
|
||||
if "masks" in sample:
|
||||
sample["masks"] = sample["masks"].flip(-1)
|
||||
|
||||
if "boxes" in target:
|
||||
target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
|
||||
if "masks" in target:
|
||||
target["masks"] = target["masks"].flip(-1)
|
||||
|
||||
return image, target
|
||||
return sample
|
||||
|
||||
|
||||
class RandomVerticalFlip(T.RandomVerticalFlip):
|
||||
def forward(
|
||||
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
|
||||
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
|
||||
class RandomVerticalFlip(nn.Module): # type: ignore[misc,name-defined]
|
||||
"""Vertically flip the given sample randomly with a given probability."""
|
||||
|
||||
def __init__(self, p: float = 0.5) -> None:
|
||||
"""Initialize a new transform instance.
|
||||
|
||||
Parameters:
|
||||
p: probability of the sample being flipped
|
||||
"""
|
||||
super().__init__()
|
||||
self.p = p
|
||||
|
||||
def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Randomly flip the image and target tensors.
|
||||
|
||||
Parameters:
|
||||
image: image to be flipped
|
||||
target: optional bounding boxes and masks to flip
|
||||
sample: a single data sample
|
||||
|
||||
Returns:
|
||||
randomly flipped image and target
|
||||
a possibly flipped sample
|
||||
"""
|
||||
if torch.rand(1) < self.p:
|
||||
image = F.vflip(image)
|
||||
if "image" in sample:
|
||||
sample["image"] = F.vflip(sample["image"])
|
||||
width, height = F._get_image_size(sample["image"])
|
||||
|
||||
if target is not None:
|
||||
width, height = F._get_image_size(image)
|
||||
if "boxes" in sample:
|
||||
sample["boxes"][:, [1, 3]] = height - sample["boxes"][:, [3, 1]]
|
||||
if "masks" in sample:
|
||||
sample["masks"] = sample["masks"].flip(-2)
|
||||
|
||||
if "boxes" in target:
|
||||
target["boxes"][:, [1, 3]] = height - target["boxes"][:, [3, 1]]
|
||||
if "masks" in target:
|
||||
target["masks"] = target["masks"].flip(-2)
|
||||
|
||||
return image, target
|
||||
return sample
|
||||
|
|
Загрузка…
Ссылка в новой задаче