зеркало из https://github.com/microsoft/archai.git
fix(task/seg): training script changes for compat w/ newer version of lightning
This commit is contained in:
Родитель
f5170158b6
Коммит
df190bfea0
|
@ -49,7 +49,7 @@ if __name__ == '__main__':
|
||||||
]
|
]
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
default_root_dir=str(args.output_dir), gpus=1,
|
default_root_dir=str(args.output_dir), accelerator='gpu',
|
||||||
val_check_interval=args.val_check_interval,
|
val_check_interval=args.val_check_interval,
|
||||||
max_epochs=args.epochs,
|
max_epochs=args.epochs,
|
||||||
callbacks=callbacks
|
callbacks=callbacks
|
||||||
|
|
|
@ -40,52 +40,32 @@ class SegmentationTrainingLoop(pl.LightningModule):
|
||||||
pred_classes.cpu(), mask.cpu(), self.num_classes, self.ignore_mask_value
|
pred_classes.cpu(), mask.cpu(), self.num_classes, self.ignore_mask_value
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
iou_dict = get_iou(confusion_matrix)
|
||||||
'loss': loss,
|
f1_dict = get_f1_scores(confusion_matrix)
|
||||||
'confusion_matrix': confusion_matrix
|
|
||||||
|
results = {
|
||||||
|
f'{stage}_loss': loss,
|
||||||
|
f'{stage}_mIOU': iou_dict['mIOU'],
|
||||||
|
f'{stage}_macro_f1': f1_dict['macro_f1'],
|
||||||
|
f'{stage}_weighted_f1': f1_dict['weighted_f1']
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx):
|
||||||
results = self.shared_step(batch, stage='train')
|
results = self.shared_step(batch, stage='train')
|
||||||
self.log_dict({'training_loss': results['loss']}, sync_dist=True)
|
self.log_dict(results, sync_dist=True, on_step=True, on_epoch=True)
|
||||||
|
return results['train_loss']
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
results = self.shared_step(batch, stage='validation')
|
||||||
|
self.log_dict(results, sync_dist=True, on_step=False, on_epoch=True)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def predict(self, image):
|
def predict(self, image):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
return self.model.predict(image)
|
return self.model.predict(image)
|
||||||
|
|
||||||
def validation_step(self, batch, batch_idx):
|
|
||||||
results = self.shared_step(batch, stage='validation')
|
|
||||||
return results
|
|
||||||
|
|
||||||
def validation_epoch_end(self, outputs):
|
|
||||||
self.shared_epoch_end(outputs, stage='validation')
|
|
||||||
|
|
||||||
def shared_epoch_end(self, outputs, stage):
|
|
||||||
confusion_matrix = sum([x['confusion_matrix'] for x in outputs])
|
|
||||||
avg_loss = torch.tensor([x['loss'] for x in outputs]).mean()
|
|
||||||
|
|
||||||
iou_dict = get_iou(confusion_matrix)
|
|
||||||
f1_dict = get_f1_scores(confusion_matrix)
|
|
||||||
|
|
||||||
results = {
|
|
||||||
f'{stage}_loss': avg_loss,
|
|
||||||
f'{stage}_mIOU': iou_dict['mIOU'],
|
|
||||||
f'{stage}_macro_f1': f1_dict['macro_f1'],
|
|
||||||
f'{stage}_weighted_f1': f1_dict['weighted_f1']
|
|
||||||
}
|
|
||||||
|
|
||||||
self.log_dict(results, sync_dist=True)
|
|
||||||
return results
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
|
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||||
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.973435286)
|
return [optimizer]
|
||||||
|
|
||||||
scheduler = {
|
|
||||||
'scheduler': scheduler,
|
|
||||||
'interval': 'epoch'
|
|
||||||
}
|
|
||||||
|
|
||||||
return [optimizer], [scheduler]
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче