diff --git a/hi-ml-cpath/src/health_cpath/datamodules/base_module.py b/hi-ml-cpath/src/health_cpath/datamodules/base_module.py index 3289b074..a002e637 100644 --- a/hi-ml-cpath/src/health_cpath/datamodules/base_module.py +++ b/hi-ml-cpath/src/health_cpath/datamodules/base_module.py @@ -25,7 +25,7 @@ from monai.data.dataset import CacheDataset, Dataset, PersistentDataset from monai.transforms import Compose -_SlidesOrTilesDataset = TypeVar('_SlidesOrTilesDataset', SlidesDataset, TilesDataset) +_SlidesOrTilesDataset = TypeVar("_SlidesOrTilesDataset", SlidesDataset, TilesDataset) class CacheMode(Enum): @@ -57,6 +57,7 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]): pl_replace_sampler_ddp: bool = True, dataloader_kwargs: Optional[Dict[str, Any]] = None, dataframe_kwargs: Optional[Dict[str, Any]] = None, + class_weights: Optional[torch.Tensor] = None, ) -> None: """ :param root_path: Root directory of the source dataset. @@ -80,6 +81,7 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]): :param pl_replace_sampler_ddp: If True, replace the sampler with a DistributedSampler when using DDP. :param dataloader_kwargs: Additional keyword arguments for the training, validation, and test dataloaders. :param dataframe_kwargs: Keyword arguments to pass to `pd.read_csv()` when loading the dataset CSV. + :param class_weights: Class weights to use for the dataset. If None, will compute them from the dataset. """ batch_size_inf = batch_size_inf or batch_size @@ -97,7 +99,7 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]): self.test_dataset: _SlidesOrTilesDataset self.dataframe_kwargs = dataframe_kwargs or {} self.train_dataset, self.val_dataset, self.test_dataset = self.get_splits() - self.class_weights = self.train_dataset.get_class_weights() + self.class_weights = class_weights if class_weights is not None else self.train_dataset.get_class_weights() self.seed = seed self.dataloader_kwargs = dataloader_kwargs or {} diff --git a/hi-ml-cpath/testhisto/testhisto/datamodules/test_histo_datamodule.py b/hi-ml-cpath/testhisto/testhisto/datamodules/test_histo_datamodule.py index dd16b4fe..3f996267 100644 --- a/hi-ml-cpath/testhisto/testhisto/datamodules/test_histo_datamodule.py +++ b/hi-ml-cpath/testhisto/testhisto/datamodules/test_histo_datamodule.py @@ -182,3 +182,21 @@ def test_assertion_error_missing_seed(mock_panda_slides_root_dir: Path) -> None: loading_params=get_loading_params(), ) slides_datamodule._get_ddp_sampler(MagicMock(), ModelKey.TRAIN) + + +def test_histo_module_class_weights(mock_panda_slides_root_dir: Path) -> None: + """Test if the class weights argument of the HistoDataModule is correctly set.""" + datamodule = PandaSlidesDataModule( + root_path=mock_panda_slides_root_dir, + tiling_params=TilingParams(), + loading_params=get_loading_params(), + ) + assert datamodule.class_weights.shape != (0,) + new_class_weights = torch.tensor([]) + datamodule = PandaSlidesDataModule( + root_path=mock_panda_slides_root_dir, + tiling_params=TilingParams(), + loading_params=get_loading_params(), + class_weights=new_class_weights, + ) + assert datamodule.class_weights.shape == new_class_weights.shape