ENH: Add command line arg for custom SSL checkpoint (#560)

Add command line arg for custom SSL checkpoint
This commit is contained in:
Melissa Bristow 2022-08-09 08:25:49 +01:00 коммит произвёл GitHub
Родитель 5060e28ea4
Коммит 6c9e4c9d17
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 32 добавлений и 14 удалений

16
.github/workflows/cpath-pr.yml поставляемый
Просмотреть файл

@@ -151,6 +151,22 @@ jobs:
cd ${{ env.folder }}
make smoke_test_tilespandaimagenetmil_aml
smoke_test_tcgacrcksslmil_aml:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
with:
lfs: true
- name: Set up smoke test environment
id: setup-sslmil-smoke-test-environment
uses: ./.github/actions/prepare_smoke_test_environment
- name: smoke test
run: |
cd ${{ env.folder }}
make smoke_test_tcgacrcksslmil_aml
smoke_test_crck_simclr_aml:
runs-on: ubuntu-20.04
steps:

Просмотреть файл

@@ -77,6 +77,8 @@ pytest:
pytest_coverage:
pytest --cov=health_cpath --cov SSL --cov-branch --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=.coveragerc
SSL_CKPT_RUN_ID_CRCK := CRCK_SimCLR_1655731022_85790606
# Run regression tests and compare performance
define BASE_CPATH_RUNNER_COMMAND
cd ../ ; \
@@ -92,7 +94,7 @@ define DEEPSMILEPANDATILES_ARGS
endef
define TCGACRCKSSLMIL_ARGS
--model=health_cpath.TcgaCrckSSLMIL
--model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint_run_id=${SSL_CKPT_RUN_ID_CRCK}
endef
define CRCKSIMCLR_ARGS
@@ -193,9 +195,6 @@ smoke_test_tilespandaimagenetmil_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEPANDATILES_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
# Note: this test doesn't currently run in hi-ml Workspace since the checkpoint run specified in run_ids
# innereye_ssl_checkpoint_crck_4ws does not exist there. Once we can specify alternative checkpoints
# this can be run with any Workspace
# The following test takes about 30 seconds
smoke_test_tcgacrcksslmil_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \

Просмотреть файл

@@ -79,6 +79,8 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
default=False,
doc="If True, will use classifier weights from pretrained model specified in src_checkpoint. If False, will "
"initiliaze classifier with random weights.")
ssl_checkpoint_run_id: str = param.String(default="", doc="Optional run id from which to load checkpoint if "
"using SSLEncoder")
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
@@ -108,9 +110,6 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
def cache_dir(self) -> Path:
return Path(f"/tmp/himl_cache/{self.__class__.__name__}-{self.encoder_type}/")
def setup(self) -> None:
self.ssl_ckpt_run_id = ""
def get_test_plot_options(self) -> Set[PlotOption]:
options = {PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX}
if self.num_top_slides > 0:
@@ -232,7 +231,7 @@ class BaseMILTiles(BaseMIL):
def get_transforms_dict(self, image_key: str) -> Dict[ModelKey, Union[Callable, None]]:
if self.is_caching:
encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.ssl_ckpt_run_id,
encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.ssl_checkpoint_run_id,
self.outputs_folder)
transform = Compose([
LoadTilesBatchd(image_key, progress=True),
@@ -254,7 +253,7 @@ class BaseMILTiles(BaseMIL):
pretrained_classifier=self.pretrained_classifier,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
ssl_ckpt_run_id=self.ssl_ckpt_run_id,
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
@@ -297,7 +296,7 @@ class BaseMILSlides(BaseMIL):
pretrained_classifier=self.pretrained_classifier,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
ssl_ckpt_run_id=self.ssl_ckpt_run_id,
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),

Просмотреть файл

@@ -55,7 +55,8 @@ class DeepSMILECrck(BaseMILTiles):
def setup(self) -> None:
super().setup()
self.ssl_ckpt_run_id = innereye_ssl_checkpoint_crck_4ws
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_crck_4ws
def get_data_module(self) -> TilesDataModule:
return TcgaCrckTilesDataModule(

Просмотреть файл

@@ -70,7 +70,8 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
def setup(self) -> None:
BaseMILTiles.setup(self)
self.ssl_ckpt_run_id = innereye_ssl_checkpoint_binary
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary
def get_data_module(self) -> PandaTilesDataModule:
return PandaTilesDataModule(
@@ -135,7 +136,8 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
def setup(self) -> None:
BaseMILSlides.setup(self)
self.ssl_ckpt_run_id = innereye_ssl_checkpoint_binary
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary
def get_dataloader_kwargs(self) -> dict:
return dict(

Просмотреть файл

@@ -127,7 +127,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
class_weights=self.data_module.class_weights,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
ssl_ckpt_run_id=self.ssl_ckpt_run_id,
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),

Просмотреть файл

@@ -37,6 +37,7 @@ class MockDeepSMILETilesPanda(DeepSMILETilesPanda):
# declared in TrainerParams:
max_epochs=2,
crossval_count=1,
ssl_checkpoint_run_id=""
)
default_kwargs.update(kwargs)
super().__init__(**default_kwargs)