From df190bfea002bc7ca0528fb44522c5310291e063 Mon Sep 17 00:00:00 2001 From: piero2c Date: Thu, 30 Mar 2023 11:15:20 -0700 Subject: [PATCH] fix(task/seg): training script changes for compat w/ newer version of lightning --- tasks/face_segmentation/train.py | 2 +- .../face_segmentation/training/pl_trainer.py | 52 ++++++------------- 2 files changed, 17 insertions(+), 37 deletions(-) diff --git a/tasks/face_segmentation/train.py b/tasks/face_segmentation/train.py index 9eaf5f30..846fe337 100644 --- a/tasks/face_segmentation/train.py +++ b/tasks/face_segmentation/train.py @@ -49,7 +49,7 @@ if __name__ == '__main__': ] trainer = Trainer( - default_root_dir=str(args.output_dir), gpus=1, + default_root_dir=str(args.output_dir), accelerator='gpu', val_check_interval=args.val_check_interval, max_epochs=args.epochs, callbacks=callbacks diff --git a/tasks/face_segmentation/training/pl_trainer.py b/tasks/face_segmentation/training/pl_trainer.py index bd9dee1d..2a88eff4 100644 --- a/tasks/face_segmentation/training/pl_trainer.py +++ b/tasks/face_segmentation/training/pl_trainer.py @@ -40,52 +40,32 @@ class SegmentationTrainingLoop(pl.LightningModule): pred_classes.cpu(), mask.cpu(), self.num_classes, self.ignore_mask_value ) - return { - 'loss': loss, - 'confusion_matrix': confusion_matrix + iou_dict = get_iou(confusion_matrix) + f1_dict = get_f1_scores(confusion_matrix) + + results = { + f'{stage}_loss': loss, + f'{stage}_mIOU': iou_dict['mIOU'], + f'{stage}_macro_f1': f1_dict['macro_f1'], + f'{stage}_weighted_f1': f1_dict['weighted_f1'] } + return results + def training_step(self, batch, batch_idx): results = self.shared_step(batch, stage='train') - self.log_dict({'training_loss': results['loss']}, sync_dist=True) + self.log_dict(results, sync_dist=True, on_step=True, on_epoch=True) + return results['train_loss'] + def validation_step(self, batch, batch_idx): + results = self.shared_step(batch, stage='validation') + self.log_dict(results, sync_dist=True, on_step=False, on_epoch=True) return results def predict(self, image): with torch.no_grad(): return self.model.predict(image) - def validation_step(self, batch, batch_idx): - results = self.shared_step(batch, stage='validation') - return results - - def validation_epoch_end(self, outputs): - self.shared_epoch_end(outputs, stage='validation') - - def shared_epoch_end(self, outputs, stage): - confusion_matrix = sum([x['confusion_matrix'] for x in outputs]) - avg_loss = torch.tensor([x['loss'] for x in outputs]).mean() - - iou_dict = get_iou(confusion_matrix) - f1_dict = get_f1_scores(confusion_matrix) - - results = { - f'{stage}_loss': avg_loss, - f'{stage}_mIOU': iou_dict['mIOU'], - f'{stage}_macro_f1': f1_dict['macro_f1'], - f'{stage}_weighted_f1': f1_dict['weighted_f1'] - } - - self.log_dict(results, sync_dist=True) - return results - def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) - scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.973435286) - - scheduler = { - 'scheduler': scheduler, - 'interval': 'epoch' - } - - return [optimizer], [scheduler] + return [optimizer]