Add random horizontal/vertical flip transforms

This commit is contained in:
Adam J. Stewart 2021-05-14 13:41:24 -05:00
Родитель 0875357b7e
Коммит a48f81af4f
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: C66C0675661156FC
3 изменённых файлов: 65 добавлений и 0 удалений

Просмотреть файл

@ -5,4 +5,5 @@ mypy
Pillow
pycocotools
rarfile
torch
torchvision

Просмотреть файл

@ -0,0 +1,4 @@
from .transforms import RandomHorizontalFlip, RandomVerticalFlip
__all__ = ("RandomHorizontalFlip", "RandomVerticalFlip")

Просмотреть файл

@ -0,0 +1,60 @@
from typing import Dict, Optional, Tuple
import torch
from torch import Tensor
import torchvision.transforms as T
import torchvision.transforms.functional as F
class RandomHorizontalFlip(T.RandomHorizontalFlip):
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
"""Randomly flip the image and target tensors.
Parameters:
image: image to be flipped
target: optional bounding boxes and masks to flip
Returns:
randomly flipped image and target
"""
if torch.rand(1) < self.p:
image = F.hflip(image)
if target is not None:
width, height = F._get_image_size(image)
if "boxes" in target:
target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
if "masks" in target:
target["masks"] = target["masks"].flip(-1)
return image, target
class RandomVerticalFlip(T.RandomVerticalFlip):
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
"""Randomly flip the image and target tensors.
Parameters:
image: image to be flipped
target: optional bounding boxes and masks to flip
Returns:
randomly flipped image and target
"""
if torch.rand(1) < self.p:
image = F.vflip(image)
if target is not None:
width, height = F._get_image_size(image)
if "boxes" in target:
target["boxes"][:, [1, 3]] = height - target["boxes"][:, [3, 1]]
if "masks" in target:
target["masks"] = target["masks"].flip(-2)
return image, target