From d552502c6dff4e3bc584d82044baf646bd0fac35 Mon Sep 17 00:00:00 2001 From: Wei-ge Chen Date: Thu, 20 Apr 2023 13:36:14 -0700 Subject: [PATCH] simplying code --- tasks/facial_landmark_detection/transforms.py | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/tasks/facial_landmark_detection/transforms.py b/tasks/facial_landmark_detection/transforms.py index f4f41b7a..5232d816 100644 --- a/tasks/facial_landmark_detection/transforms.py +++ b/tasks/facial_landmark_detection/transforms.py @@ -84,36 +84,30 @@ class WarpRegion(): return points_h[:, :2] / points_h[:, 2][..., None] # Dehomogenize -class GetWarpRegion(): - """Builds a warp region for the face, with square bounds enclosing the given 2D landmarks.""" +class ExtractWarpRegion(): + """Extract the Warp Region from a sample.""" def __init__(self, roi_size, scale=2.0): self.roi_size = roi_size self.scale = scale + self.kwargs_bgr = {"flags": cv2.INTER_AREA, "borderMode": cv2.BORDER_REPLICATE} - def __call__(self, sample: Sample): + def get_warp_region(self, sample: Sample): assert sample.landmarks is not None ldmks_2d = sample.landmarks - sample.warp_region = WarpRegion(*get_square_bounds(ldmks_2d), self.roi_size) - sample.warp_region.scale(self.scale) + warp_region = WarpRegion(*get_square_bounds(ldmks_2d), self.roi_size) + warp_region.scale(self.scale) - return sample - -class ExtractWarpRegion(): - """Extract the Warp Region from a sample.""" - - def __init__(self): - self.kwargs_bgr = {"flags": cv2.INTER_AREA, "borderMode": cv2.BORDER_REPLICATE} + return warp_region def __call__(self, sample : tuple): assert sample.image is not None - assert sample.warp_region is not None - warp_region = sample.warp_region + warp_region = self.get_warp_region(sample) sample.image = warp_region.extract_from_image(sample.image, **self.kwargs_bgr) @@ -150,8 +144,7 @@ class FaceLandmarkTransform: ): self.transform = Compose( [ - GetWarpRegion(roi_size = crop_size), - ExtractWarpRegion(), + ExtractWarpRegion(roi_size = crop_size), SampleToTensor(), NormalizeCoordinates() ]