зеркало из https://github.com/microsoft/hi-ml.git
BUG: Retrieve tiles_ids from outputs dictionary for loss analysis (#641)
Slides pipeline don't have tile_ids in the batch, we instead retrieve them from results
This commit is contained in:
Родитель
b2e873daa5
Коммит
c156faddb9
|
@ -231,6 +231,22 @@ jobs:
|
|||
cd ${{ env.folder }}
|
||||
make smoke_test_crck_loss_analysis_aml
|
||||
|
||||
smoke_test_slides_panda_loss_analysis:
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Set up smoke test environment
|
||||
id: setup-finetuning-smoke-test-environment
|
||||
uses: ./.github/actions/prepare_smoke_test_environment
|
||||
|
||||
- name: smoke test
|
||||
run: |
|
||||
cd ${{ env.folder }}
|
||||
make smoke_test_slides_panda_loss_analysis_aml
|
||||
|
||||
publish:
|
||||
runs-on: ubuntu-20.04
|
||||
needs: [
|
||||
|
@ -242,6 +258,7 @@ jobs:
|
|||
smoke_test_crck_simclr_aml,
|
||||
smoke_test_crck_flexible_finetuning_aml,
|
||||
smoke_test_crck_loss_analysis,
|
||||
smoke_test_slides_panda_loss_analysis,
|
||||
]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
|
|
@ -247,6 +247,14 @@ smoke_test_crck_loss_analysis_aml:
|
|||
{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
|
||||
${TCGACRCKSSLMIL_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS} ${LOSS_ANALYSIS_ARGS};}
|
||||
|
||||
smoke tests local: smoke_test_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local smoke_test_crck_flexible_finetuning_local smoke_test_tcgacrckimagenetmil_local smoke_test_crck_loss_analysis_local
|
||||
smoke_test_slides_panda_loss_analysis_local:
|
||||
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
|
||||
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${LOSS_ANALYSIS_ARGS};}
|
||||
|
||||
smoke tests AML: smoke_test_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml smoke_test_crck_flexible_finetuning_aml smoke_test_tcgacrckimagenetmil_aml smoke_test_crck_loss_analysis_aml
|
||||
smoke_test_slides_panda_loss_analysis_aml:
|
||||
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
|
||||
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${LOSS_ANALYSIS_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
|
||||
|
||||
smoke tests local: smoke_test_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local smoke_test_crck_flexible_finetuning_local smoke_test_tcgacrckimagenetmil_local smoke_test_crck_loss_analysis_local smoke_test_slides_panda_loss_analysis_local
|
||||
|
||||
smoke tests AML: smoke_test_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml smoke_test_crck_flexible_finetuning_aml smoke_test_tcgacrckimagenetmil_aml smoke_test_crck_loss_analysis_aml smoke_test_slides_panda_loss_analysis_aml
|
||||
|
|
|
@ -132,6 +132,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
|
|||
pooling_params=create_from_matching_params(self, PoolingParams),
|
||||
optimizer_params=create_from_matching_params(self, OptimizerParams),
|
||||
outputs_handler=outputs_handler,
|
||||
analyse_loss=self.analyse_loss,
|
||||
)
|
||||
outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
|
||||
return deepmil_module
|
||||
|
|
|
@ -360,7 +360,8 @@ class BaseDeepMILModule(LightningModule):
|
|||
results = {ResultsKey.LOSS: train_result[ResultsKey.LOSS]}
|
||||
if self.analyse_loss:
|
||||
results.update({ResultsKey.LOSS_PER_SAMPLE: train_result[ResultsKey.LOSS_PER_SAMPLE],
|
||||
ResultsKey.CLASS_PROBS: train_result[ResultsKey.CLASS_PROBS]})
|
||||
ResultsKey.CLASS_PROBS: train_result[ResultsKey.CLASS_PROBS],
|
||||
ResultsKey.TILE_ID: train_result[ResultsKey.TILE_ID]})
|
||||
return results
|
||||
|
||||
def validation_step(self, batch: Dict, batch_idx: int) -> BatchResultsType: # type: ignore
|
||||
|
|
|
@ -440,7 +440,7 @@ class LossAnalysisCallback(Callback):
|
|||
self.loss_cache[stage][ResultsKey.ENTROPY].extend(self.compute_entropy(outputs[ResultsKey.CLASS_PROBS]))
|
||||
if self.save_tile_ids:
|
||||
self.loss_cache[stage][ResultsKey.TILE_ID].extend(
|
||||
[self.TILES_JOIN_TOKEN.join(tiles) for tiles in batch[ResultsKey.TILE_ID]]
|
||||
[self.TILES_JOIN_TOKEN.join(tiles) for tiles in outputs[ResultsKey.TILE_ID]]
|
||||
)
|
||||
|
||||
def synchronise_processes_and_reset(self, trainer: Trainer, pl_module: BaseDeepMILModule, stage: ModelKey) -> None:
|
||||
|
|
Загрузка…
Ссылка в новой задаче