зеркало из https://github.com/microsoft/archai.git
feat(seg): adds {tr,val}_dataloader_workers arg
This commit is contained in:
Родитель
0010ad7cd7
Коммит
008e573b5f
|
@ -368,7 +368,10 @@ class EvolutionParetoSearchSegmentation(EvolutionParetoSearch):
|
||||||
augmentation=self.augmentation,
|
augmentation=self.augmentation,
|
||||||
batch_size=self.batch_size, lr=self.lr,
|
batch_size=self.batch_size, lr=self.lr,
|
||||||
lr_exp_decay_gamma=self.lr_exp_decay_gamma,
|
lr_exp_decay_gamma=self.lr_exp_decay_gamma,
|
||||||
criterion_name=self.criterion_name, seed=self.seed)
|
criterion_name=self.criterion_name, seed=self.seed,
|
||||||
|
tr_dataloader_workers=self.conf_train['tr_dataloader_workers'],
|
||||||
|
val_dataloader_workers=self.conf_train['val_dataloader_workers'],
|
||||||
|
)
|
||||||
return ref
|
return ref
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
|
|
|
@ -159,6 +159,8 @@ class SegmentationTrainer():
|
||||||
lr: float = 2e-4, criterion_name: str = 'ce',
|
lr: float = 2e-4, criterion_name: str = 'ce',
|
||||||
val_check_interval: Union[int, float] = 0.25,
|
val_check_interval: Union[int, float] = 0.25,
|
||||||
lr_exp_decay_gamma: float = 0.98,
|
lr_exp_decay_gamma: float = 0.98,
|
||||||
|
tr_dataloader_workers: int = 3,
|
||||||
|
val_dataloader_workers: int = 1,
|
||||||
seed: int = 1):
|
seed: int = 1):
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
@ -172,8 +174,12 @@ class SegmentationTrainer():
|
||||||
self.val_dataset = FaceSynthetics(self.data_dir, subset='validation', val_size=val_size,
|
self.val_dataset = FaceSynthetics(self.data_dir, subset='validation', val_size=val_size,
|
||||||
img_size=(img_size, img_size), augmentation=augmentation)
|
img_size=(img_size, img_size), augmentation=augmentation)
|
||||||
|
|
||||||
self.tr_dataloader = DataLoader(self.tr_dataset, batch_size=batch_size, num_workers=8, shuffle=True)
|
self.tr_dataloader = DataLoader(
|
||||||
self.val_dataloader = DataLoader(self.val_dataset, batch_size=batch_size, num_workers=8, shuffle=False)
|
self.tr_dataset, batch_size=batch_size, num_workers=tr_dataloader_workers, shuffle=True
|
||||||
|
)
|
||||||
|
self.val_dataloader = DataLoader(
|
||||||
|
self.val_dataset, batch_size=batch_size, num_workers=val_dataloader_workers, shuffle=False
|
||||||
|
)
|
||||||
|
|
||||||
self.model = LightningModelWrapper(model, criterion_name=criterion_name, lr=lr,
|
self.model = LightningModelWrapper(model, criterion_name=criterion_name, lr=lr,
|
||||||
img_size=img_size, lr_exp_decay_gamma=lr_exp_decay_gamma)
|
img_size=img_size, lr_exp_decay_gamma=lr_exp_decay_gamma)
|
||||||
|
|
|
@ -115,3 +115,6 @@ nas:
|
||||||
criterion_name: 'ce'
|
criterion_name: 'ce'
|
||||||
val_check_interval: 1.0 # how often to evaluate validation accuracy per epoch
|
val_check_interval: 1.0 # how often to evaluate validation accuracy per epoch
|
||||||
lr_exp_decay_gamma: 0.973435286
|
lr_exp_decay_gamma: 0.973435286
|
||||||
|
tr_dataloader_workers: 3
|
||||||
|
val_dataloader_workers: 1
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче