From 7b334a2a4ffd91e1195feb05c52b18364051ccdd Mon Sep 17 00:00:00 2001 From: Wei-ge Chen Date: Wed, 26 Apr 2023 11:28:23 -0700 Subject: [PATCH] Further clean up --- tasks/facial_landmark_detection/train.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tasks/facial_landmark_detection/train.py b/tasks/facial_landmark_detection/train.py index f84da3dd..372267dc 100644 --- a/tasks/facial_landmark_detection/train.py +++ b/tasks/facial_landmark_detection/train.py @@ -161,21 +161,12 @@ def train(args, model: nn.Module = None): collate_fn = None num_classes = dataset.dataset.num_landmarks - mixup_transforms = [] - if args.mixup_alpha > 0.0: - mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha)) - if args.cutmix_alpha > 0.0: - mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha)) - if mixup_transforms: - mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) - collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731 data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True, - collate_fn=collate_fn, ) data_loader_test = torch.utils.data.DataLoader( dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True