зеркало из https://github.com/microsoft/archai.git
Further clean up
This commit is contained in:
Родитель
8e89402d40
Коммит
7b334a2a4f
|
@ -161,21 +161,12 @@ def train(args, model: nn.Module = None):
|
|||
|
||||
collate_fn = None
|
||||
num_classes = dataset.dataset.num_landmarks
|
||||
mixup_transforms = []
|
||||
if args.mixup_alpha > 0.0:
|
||||
mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
|
||||
if args.cutmix_alpha > 0.0:
|
||||
mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
|
||||
if mixup_transforms:
|
||||
mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
|
||||
collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
sampler=train_sampler,
|
||||
num_workers=args.workers,
|
||||
pin_memory=True,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
data_loader_test = torch.utils.data.DataLoader(
|
||||
dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True
|
||||
|
|
Загрузка…
Ссылка в новой задаче