fix(task/seg): training script changes for compat w/ newer version of lightning

This commit is contained in:
piero2c 2023-03-30 11:15:20 -07:00
Родитель f5170158b6
Коммит df190bfea0
2 изменённых файлов: 17 добавлений и 37 удалений

Просмотреть файл

@ -49,7 +49,7 @@ if __name__ == '__main__':
]
trainer = Trainer(
default_root_dir=str(args.output_dir), gpus=1,
default_root_dir=str(args.output_dir), accelerator='gpu',
val_check_interval=args.val_check_interval,
max_epochs=args.epochs,
callbacks=callbacks

Просмотреть файл

@ -40,52 +40,32 @@ class SegmentationTrainingLoop(pl.LightningModule):
pred_classes.cpu(), mask.cpu(), self.num_classes, self.ignore_mask_value
)
return {
'loss': loss,
'confusion_matrix': confusion_matrix
iou_dict = get_iou(confusion_matrix)
f1_dict = get_f1_scores(confusion_matrix)
results = {
f'{stage}_loss': loss,
f'{stage}_mIOU': iou_dict['mIOU'],
f'{stage}_macro_f1': f1_dict['macro_f1'],
f'{stage}_weighted_f1': f1_dict['weighted_f1']
}
return results
def training_step(self, batch, batch_idx):
results = self.shared_step(batch, stage='train')
self.log_dict({'training_loss': results['loss']}, sync_dist=True)
self.log_dict(results, sync_dist=True, on_step=True, on_epoch=True)
return results['train_loss']
def validation_step(self, batch, batch_idx):
results = self.shared_step(batch, stage='validation')
self.log_dict(results, sync_dist=True, on_step=False, on_epoch=True)
return results
def predict(self, image):
with torch.no_grad():
return self.model.predict(image)
def validation_step(self, batch, batch_idx):
results = self.shared_step(batch, stage='validation')
return results
def validation_epoch_end(self, outputs):
self.shared_epoch_end(outputs, stage='validation')
def shared_epoch_end(self, outputs, stage):
confusion_matrix = sum([x['confusion_matrix'] for x in outputs])
avg_loss = torch.tensor([x['loss'] for x in outputs]).mean()
iou_dict = get_iou(confusion_matrix)
f1_dict = get_f1_scores(confusion_matrix)
results = {
f'{stage}_loss': avg_loss,
f'{stage}_mIOU': iou_dict['mIOU'],
f'{stage}_macro_f1': f1_dict['macro_f1'],
f'{stage}_weighted_f1': f1_dict['weighted_f1']
}
self.log_dict(results, sync_dist=True)
return results
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.973435286)
scheduler = {
'scheduler': scheduler,
'interval': 'epoch'
}
return [optimizer], [scheduler]
return [optimizer]