This commit is contained in:
Wei-ge Chen 2023-04-20 13:36:14 -07:00
Родитель dab2b68a37
Коммит d552502c6d
1 изменённых файлов: 9 добавлений и 16 удалений

Просмотреть файл

@ -84,36 +84,30 @@ class WarpRegion():
return points_h[:, :2] / points_h[:, 2][..., None] # Dehomogenize
class GetWarpRegion():
"""Builds a warp region for the face, with square bounds enclosing the given 2D landmarks."""
class ExtractWarpRegion():
"""Extract the Warp Region from a sample."""
def __init__(self, roi_size, scale=2.0):
self.roi_size = roi_size
self.scale = scale
self.kwargs_bgr = {"flags": cv2.INTER_AREA, "borderMode": cv2.BORDER_REPLICATE}
def __call__(self, sample: Sample):
def get_warp_region(self, sample: Sample):
assert sample.landmarks is not None
ldmks_2d = sample.landmarks
sample.warp_region = WarpRegion(*get_square_bounds(ldmks_2d), self.roi_size)
sample.warp_region.scale(self.scale)
warp_region = WarpRegion(*get_square_bounds(ldmks_2d), self.roi_size)
warp_region.scale(self.scale)
return sample
class ExtractWarpRegion():
"""Extract the Warp Region from a sample."""
def __init__(self):
self.kwargs_bgr = {"flags": cv2.INTER_AREA, "borderMode": cv2.BORDER_REPLICATE}
return warp_region
def __call__(self, sample : tuple):
assert sample.image is not None
assert sample.warp_region is not None
warp_region = sample.warp_region
warp_region = self.get_warp_region(sample)
sample.image = warp_region.extract_from_image(sample.image, **self.kwargs_bgr)
@ -150,8 +144,7 @@ class FaceLandmarkTransform:
):
self.transform = Compose(
[
GetWarpRegion(roi_size = crop_size),
ExtractWarpRegion(),
ExtractWarpRegion(roi_size = crop_size),
SampleToTensor(),
NormalizeCoordinates()
]