BUG: Retrieve tiles_ids from outputs dictionary for loss analysis (#641)

Slides pipeline don't have tile_ids in the batch, we instead retrieve
them from results
This commit is contained in:
Kenza Bouzid 2022-10-25 12:59:17 +01:00 коммит произвёл GitHub
Родитель b2e873daa5
Коммит c156faddb9
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 31 добавлений и 4 удалений

17
.github/workflows/cpath-pr.yml поставляемый
Просмотреть файл

@ -231,6 +231,22 @@ jobs:
cd ${{ env.folder }}
make smoke_test_crck_loss_analysis_aml
smoke_test_slides_panda_loss_analysis:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
with:
lfs: true
- name: Set up smoke test environment
id: setup-finetuning-smoke-test-environment
uses: ./.github/actions/prepare_smoke_test_environment
- name: smoke test
run: |
cd ${{ env.folder }}
make smoke_test_slides_panda_loss_analysis_aml
publish:
runs-on: ubuntu-20.04
needs: [
@ -242,6 +258,7 @@ jobs:
smoke_test_crck_simclr_aml,
smoke_test_crck_flexible_finetuning_aml,
smoke_test_crck_loss_analysis,
smoke_test_slides_panda_loss_analysis,
]
steps:
- uses: actions/checkout@v3

Просмотреть файл

@ -247,6 +247,14 @@ smoke_test_crck_loss_analysis_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${TCGACRCKSSLMIL_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS} ${LOSS_ANALYSIS_ARGS};}
smoke tests local: smoke_test_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local smoke_test_crck_flexible_finetuning_local smoke_test_tcgacrckimagenetmil_local smoke_test_crck_loss_analysis_local
smoke_test_slides_panda_loss_analysis_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${LOSS_ANALYSIS_ARGS};}
smoke tests AML: smoke_test_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml smoke_test_crck_flexible_finetuning_aml smoke_test_tcgacrckimagenetmil_aml smoke_test_crck_loss_analysis_aml
smoke_test_slides_panda_loss_analysis_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDASLIDES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEDEFAULT_SMOKE_TEST_ARGS} ${LOSS_ANALYSIS_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
smoke tests local: smoke_test_slidespandaimagenetmil_local smoke_test_tilespandaimagenetmil_local smoke_test_tcgacrcksslmil_local smoke_test_crck_simclr_local smoke_test_crck_flexible_finetuning_local smoke_test_tcgacrckimagenetmil_local smoke_test_crck_loss_analysis_local smoke_test_slides_panda_loss_analysis_local
smoke tests AML: smoke_test_slidespandaimagenetmil_aml smoke_test_tilespandaimagenetmil_aml smoke_test_tcgacrcksslmil_aml smoke_test_crck_simclr_aml smoke_test_crck_flexible_finetuning_aml smoke_test_tcgacrckimagenetmil_aml smoke_test_crck_loss_analysis_aml smoke_test_slides_panda_loss_analysis_aml

Просмотреть файл

@ -132,6 +132,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
pooling_params=create_from_matching_params(self, PoolingParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
outputs_handler=outputs_handler,
analyse_loss=self.analyse_loss,
)
outputs_handler.set_slides_dataset_for_plots_handlers(self.get_slides_dataset())
return deepmil_module

Просмотреть файл

@ -360,7 +360,8 @@ class BaseDeepMILModule(LightningModule):
results = {ResultsKey.LOSS: train_result[ResultsKey.LOSS]}
if self.analyse_loss:
results.update({ResultsKey.LOSS_PER_SAMPLE: train_result[ResultsKey.LOSS_PER_SAMPLE],
ResultsKey.CLASS_PROBS: train_result[ResultsKey.CLASS_PROBS]})
ResultsKey.CLASS_PROBS: train_result[ResultsKey.CLASS_PROBS],
ResultsKey.TILE_ID: train_result[ResultsKey.TILE_ID]})
return results
def validation_step(self, batch: Dict, batch_idx: int) -> BatchResultsType: # type: ignore

Просмотреть файл

@ -440,7 +440,7 @@ class LossAnalysisCallback(Callback):
self.loss_cache[stage][ResultsKey.ENTROPY].extend(self.compute_entropy(outputs[ResultsKey.CLASS_PROBS]))
if self.save_tile_ids:
self.loss_cache[stage][ResultsKey.TILE_ID].extend(
[self.TILES_JOIN_TOKEN.join(tiles) for tiles in batch[ResultsKey.TILE_ID]]
[self.TILES_JOIN_TOKEN.join(tiles) for tiles in outputs[ResultsKey.TILE_ID]]
)
def synchronise_processes_and_reset(self, trainer: Trainer, pl_module: BaseDeepMILModule, stage: ModelKey) -> None: