feat(seg): adds {tr,val}_dataloader_workers arg

This commit is contained in:
Piero 2022-05-30 15:32:34 -07:00 коммит произвёл Gustavo Rosa
Родитель 0010ad7cd7
Коммит 008e573b5f
3 изменённых файлов: 15 добавлений и 3 удалений

Просмотреть файл

@ -368,7 +368,10 @@ class EvolutionParetoSearchSegmentation(EvolutionParetoSearch):
augmentation=self.augmentation, augmentation=self.augmentation,
batch_size=self.batch_size, lr=self.lr, batch_size=self.batch_size, lr=self.lr,
lr_exp_decay_gamma=self.lr_exp_decay_gamma, lr_exp_decay_gamma=self.lr_exp_decay_gamma,
criterion_name=self.criterion_name, seed=self.seed) criterion_name=self.criterion_name, seed=self.seed,
tr_dataloader_workers=self.conf_train['tr_dataloader_workers'],
val_dataloader_workers=self.conf_train['val_dataloader_workers'],
)
return ref return ref
@overrides @overrides

Просмотреть файл

@ -159,6 +159,8 @@ class SegmentationTrainer():
lr: float = 2e-4, criterion_name: str = 'ce', lr: float = 2e-4, criterion_name: str = 'ce',
val_check_interval: Union[int, float] = 0.25, val_check_interval: Union[int, float] = 0.25,
lr_exp_decay_gamma: float = 0.98, lr_exp_decay_gamma: float = 0.98,
tr_dataloader_workers: int = 3,
val_dataloader_workers: int = 1,
seed: int = 1): seed: int = 1):
torch.manual_seed(seed) torch.manual_seed(seed)
random.seed(seed) random.seed(seed)
@ -172,8 +174,12 @@ class SegmentationTrainer():
self.val_dataset = FaceSynthetics(self.data_dir, subset='validation', val_size=val_size, self.val_dataset = FaceSynthetics(self.data_dir, subset='validation', val_size=val_size,
img_size=(img_size, img_size), augmentation=augmentation) img_size=(img_size, img_size), augmentation=augmentation)
self.tr_dataloader = DataLoader(self.tr_dataset, batch_size=batch_size, num_workers=8, shuffle=True) self.tr_dataloader = DataLoader(
self.val_dataloader = DataLoader(self.val_dataset, batch_size=batch_size, num_workers=8, shuffle=False) self.tr_dataset, batch_size=batch_size, num_workers=tr_dataloader_workers, shuffle=True
)
self.val_dataloader = DataLoader(
self.val_dataset, batch_size=batch_size, num_workers=val_dataloader_workers, shuffle=False
)
self.model = LightningModelWrapper(model, criterion_name=criterion_name, lr=lr, self.model = LightningModelWrapper(model, criterion_name=criterion_name, lr=lr,
img_size=img_size, lr_exp_decay_gamma=lr_exp_decay_gamma) img_size=img_size, lr_exp_decay_gamma=lr_exp_decay_gamma)

Просмотреть файл

@ -115,3 +115,6 @@ nas:
criterion_name: 'ce' criterion_name: 'ce'
val_check_interval: 1.0 # how often to evaluate validation accuracy per epoch val_check_interval: 1.0 # how often to evaluate validation accuracy per epoch
lr_exp_decay_gamma: 0.973435286 lr_exp_decay_gamma: 0.973435286
tr_dataloader_workers: 3
val_dataloader_workers: 1