This commit is contained in:
Wei-ge Chen 2023-04-26 11:28:23 -07:00
Родитель 8e89402d40
Коммит 7b334a2a4f
1 изменённых файлов: 0 добавлений и 9 удалений

Просмотреть файл

@ -161,21 +161,12 @@ def train(args, model: nn.Module = None):
collate_fn = None
num_classes = dataset.dataset.num_landmarks
mixup_transforms = []
if args.mixup_alpha > 0.0:
mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
if args.cutmix_alpha > 0.0:
mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
if mixup_transforms:
mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
sampler=train_sampler,
num_workers=args.workers,
pin_memory=True,
collate_fn=collate_fn,
)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True