зеркало из https://github.com/microsoft/hi-ml.git
ENH: Add command line arg for custom SSL checkpoint (#560)
Add command line arg for custom SSL checkpoint
This commit is contained in:
Родитель
5060e28ea4
Коммит
6c9e4c9d17
|
@ -151,6 +151,22 @@ jobs:
|
|||
cd ${{ env.folder }}
|
||||
make smoke_test_tilespandaimagenetmil_aml
|
||||
|
||||
smoke_test_tcgacrcksslmil_aml:
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
lfs: true
|
||||
|
||||
- name: Set up smoke test environment
|
||||
id: setup-sslmil-smoke-test-environment
|
||||
uses: ./.github/actions/prepare_smoke_test_environment
|
||||
|
||||
- name: smoke test
|
||||
run: |
|
||||
cd ${{ env.folder }}
|
||||
make smoke_test_tcgacrcksslmil_aml
|
||||
|
||||
smoke_test_crck_simclr_aml:
|
||||
runs-on: ubuntu-20.04
|
||||
steps:
|
||||
|
|
|
@ -77,6 +77,8 @@ pytest:
|
|||
pytest_coverage:
|
||||
pytest --cov=health_cpath --cov SSL --cov-branch --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=.coveragerc
|
||||
|
||||
SSL_CKPT_RUN_ID_CRCK := CRCK_SimCLR_1655731022_85790606
|
||||
|
||||
# Run regression tests and compare performance
|
||||
define BASE_CPATH_RUNNER_COMMAND
|
||||
cd ../ ; \
|
||||
|
@ -92,7 +94,7 @@ define DEEPSMILEPANDATILES_ARGS
|
|||
endef
|
||||
|
||||
define TCGACRCKSSLMIL_ARGS
|
||||
--model=health_cpath.TcgaCrckSSLMIL
|
||||
--model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint_run_id=${SSL_CKPT_RUN_ID_CRCK}
|
||||
endef
|
||||
|
||||
define CRCKSIMCLR_ARGS
|
||||
|
@ -193,9 +195,6 @@ smoke_test_tilespandaimagenetmil_aml:
|
|||
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
|
||||
${DEEPSMILEPANDATILES_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}
|
||||
|
||||
# Note: this test doesn't currently run in hi-ml Workspace since the checkpoint run specified in run_ids
|
||||
# innereye_ssl_checkpoint_crck_4ws does not exist there. Once we can specify alternative checkpoints
|
||||
# this can be run with any Workspace
|
||||
# The following test takes about 30 seconds
|
||||
smoke_test_tcgacrcksslmil_local:
|
||||
{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
|
||||
|
|
|
@ -79,6 +79,8 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
|
|||
default=False,
|
||||
doc="If True, will use classifier weights from pretrained model specified in src_checkpoint. If False, will "
|
||||
"initiliaze classifier with random weights.")
|
||||
ssl_checkpoint_run_id: str = param.String(default="", doc="Optional run id from which to load checkpoint if "
|
||||
"using SSLEncoder")
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
@ -108,9 +110,6 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
|
|||
def cache_dir(self) -> Path:
|
||||
return Path(f"/tmp/himl_cache/{self.__class__.__name__}-{self.encoder_type}/")
|
||||
|
||||
def setup(self) -> None:
|
||||
self.ssl_ckpt_run_id = ""
|
||||
|
||||
def get_test_plot_options(self) -> Set[PlotOption]:
|
||||
options = {PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX}
|
||||
if self.num_top_slides > 0:
|
||||
|
@ -232,7 +231,7 @@ class BaseMILTiles(BaseMIL):
|
|||
|
||||
def get_transforms_dict(self, image_key: str) -> Dict[ModelKey, Union[Callable, None]]:
|
||||
if self.is_caching:
|
||||
encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.ssl_ckpt_run_id,
|
||||
encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.ssl_checkpoint_run_id,
|
||||
self.outputs_folder)
|
||||
transform = Compose([
|
||||
LoadTilesBatchd(image_key, progress=True),
|
||||
|
@ -254,7 +253,7 @@ class BaseMILTiles(BaseMIL):
|
|||
pretrained_classifier=self.pretrained_classifier,
|
||||
dropout_rate=self.dropout_rate,
|
||||
outputs_folder=self.outputs_folder,
|
||||
ssl_ckpt_run_id=self.ssl_ckpt_run_id,
|
||||
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
|
||||
encoder_params=create_from_matching_params(self, EncoderParams),
|
||||
pooling_params=create_from_matching_params(self, PoolingParams),
|
||||
optimizer_params=create_from_matching_params(self, OptimizerParams),
|
||||
|
@ -297,7 +296,7 @@ class BaseMILSlides(BaseMIL):
|
|||
pretrained_classifier=self.pretrained_classifier,
|
||||
dropout_rate=self.dropout_rate,
|
||||
outputs_folder=self.outputs_folder,
|
||||
ssl_ckpt_run_id=self.ssl_ckpt_run_id,
|
||||
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
|
||||
encoder_params=create_from_matching_params(self, EncoderParams),
|
||||
pooling_params=create_from_matching_params(self, PoolingParams),
|
||||
optimizer_params=create_from_matching_params(self, OptimizerParams),
|
||||
|
|
|
@ -55,7 +55,8 @@ class DeepSMILECrck(BaseMILTiles):
|
|||
|
||||
def setup(self) -> None:
|
||||
super().setup()
|
||||
self.ssl_ckpt_run_id = innereye_ssl_checkpoint_crck_4ws
|
||||
# If no SSL checkpoint is provided, use the default one
|
||||
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_crck_4ws
|
||||
|
||||
def get_data_module(self) -> TilesDataModule:
|
||||
return TcgaCrckTilesDataModule(
|
||||
|
|
|
@ -70,7 +70,8 @@ class DeepSMILETilesPanda(BaseMILTiles, BaseDeepSMILEPanda):
|
|||
|
||||
def setup(self) -> None:
|
||||
BaseMILTiles.setup(self)
|
||||
self.ssl_ckpt_run_id = innereye_ssl_checkpoint_binary
|
||||
# If no SSL checkpoint is provided, use the default one
|
||||
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary
|
||||
|
||||
def get_data_module(self) -> PandaTilesDataModule:
|
||||
return PandaTilesDataModule(
|
||||
|
@ -135,7 +136,8 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
|
|||
|
||||
def setup(self) -> None:
|
||||
BaseMILSlides.setup(self)
|
||||
self.ssl_ckpt_run_id = innereye_ssl_checkpoint_binary
|
||||
# If no SSL checkpoint is provided, use the default one
|
||||
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary
|
||||
|
||||
def get_dataloader_kwargs(self) -> dict:
|
||||
return dict(
|
||||
|
|
|
@ -127,7 +127,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
|
|||
class_weights=self.data_module.class_weights,
|
||||
dropout_rate=self.dropout_rate,
|
||||
outputs_folder=self.outputs_folder,
|
||||
ssl_ckpt_run_id=self.ssl_ckpt_run_id,
|
||||
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
|
||||
encoder_params=create_from_matching_params(self, EncoderParams),
|
||||
pooling_params=create_from_matching_params(self, PoolingParams),
|
||||
optimizer_params=create_from_matching_params(self, OptimizerParams),
|
||||
|
|
|
@ -37,6 +37,7 @@ class MockDeepSMILETilesPanda(DeepSMILETilesPanda):
|
|||
# declared in TrainerParams:
|
||||
max_epochs=2,
|
||||
crossval_count=1,
|
||||
ssl_checkpoint_run_id=""
|
||||
)
|
||||
default_kwargs.update(kwargs)
|
||||
super().__init__(**default_kwargs)
|
||||
|
|
Загрузка…
Ссылка в новой задаче