зеркало из https://github.com/microsoft/archai.git
fix(task/seg): training script changes for compat w/ newer version of lightning
This commit is contained in:
Родитель
f5170158b6
Коммит
df190bfea0
|
@ -49,7 +49,7 @@ if __name__ == '__main__':
|
|||
]
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=str(args.output_dir), gpus=1,
|
||||
default_root_dir=str(args.output_dir), accelerator='gpu',
|
||||
val_check_interval=args.val_check_interval,
|
||||
max_epochs=args.epochs,
|
||||
callbacks=callbacks
|
||||
|
|
|
@ -40,52 +40,32 @@ class SegmentationTrainingLoop(pl.LightningModule):
|
|||
pred_classes.cpu(), mask.cpu(), self.num_classes, self.ignore_mask_value
|
||||
)
|
||||
|
||||
return {
|
||||
'loss': loss,
|
||||
'confusion_matrix': confusion_matrix
|
||||
iou_dict = get_iou(confusion_matrix)
|
||||
f1_dict = get_f1_scores(confusion_matrix)
|
||||
|
||||
results = {
|
||||
f'{stage}_loss': loss,
|
||||
f'{stage}_mIOU': iou_dict['mIOU'],
|
||||
f'{stage}_macro_f1': f1_dict['macro_f1'],
|
||||
f'{stage}_weighted_f1': f1_dict['weighted_f1']
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
results = self.shared_step(batch, stage='train')
|
||||
self.log_dict({'training_loss': results['loss']}, sync_dist=True)
|
||||
self.log_dict(results, sync_dist=True, on_step=True, on_epoch=True)
|
||||
return results['train_loss']
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
results = self.shared_step(batch, stage='validation')
|
||||
self.log_dict(results, sync_dist=True, on_step=False, on_epoch=True)
|
||||
return results
|
||||
|
||||
def predict(self, image):
|
||||
with torch.no_grad():
|
||||
return self.model.predict(image)
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
results = self.shared_step(batch, stage='validation')
|
||||
return results
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
self.shared_epoch_end(outputs, stage='validation')
|
||||
|
||||
def shared_epoch_end(self, outputs, stage):
|
||||
confusion_matrix = sum([x['confusion_matrix'] for x in outputs])
|
||||
avg_loss = torch.tensor([x['loss'] for x in outputs]).mean()
|
||||
|
||||
iou_dict = get_iou(confusion_matrix)
|
||||
f1_dict = get_f1_scores(confusion_matrix)
|
||||
|
||||
results = {
|
||||
f'{stage}_loss': avg_loss,
|
||||
f'{stage}_mIOU': iou_dict['mIOU'],
|
||||
f'{stage}_macro_f1': f1_dict['macro_f1'],
|
||||
f'{stage}_weighted_f1': f1_dict['weighted_f1']
|
||||
}
|
||||
|
||||
self.log_dict(results, sync_dist=True)
|
||||
return results
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.973435286)
|
||||
|
||||
scheduler = {
|
||||
'scheduler': scheduler,
|
||||
'interval': 'epoch'
|
||||
}
|
||||
|
||||
return [optimizer], [scheduler]
|
||||
return [optimizer]
|
||||
|
|
Загрузка…
Ссылка в новой задаче