ENH: Add a class_weights argument to the data module (#871)

This commit is contained in:
Anton Schwaighofer 2023-04-17 20:48:32 +01:00 коммит произвёл GitHub
Родитель 8dbd3249cc
Коммит 5746042f68
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 22 добавлений и 2 удалений

Просмотреть файл

@@ -25,7 +25,7 @@ from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
from monai.transforms import Compose
_SlidesOrTilesDataset = TypeVar('_SlidesOrTilesDataset', SlidesDataset, TilesDataset)
_SlidesOrTilesDataset = TypeVar("_SlidesOrTilesDataset", SlidesDataset, TilesDataset)
class CacheMode(Enum):
@@ -57,6 +57,7 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
pl_replace_sampler_ddp: bool = True,
dataloader_kwargs: Optional[Dict[str, Any]] = None,
dataframe_kwargs: Optional[Dict[str, Any]] = None,
class_weights: Optional[torch.Tensor] = None,
) -> None:
"""
:param root_path: Root directory of the source dataset.
@@ -80,6 +81,7 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
:param pl_replace_sampler_ddp: If True, replace the sampler with a DistributedSampler when using DDP.
:param dataloader_kwargs: Additional keyword arguments for the training, validation, and test dataloaders.
:param dataframe_kwargs: Keyword arguments to pass to `pd.read_csv()` when loading the dataset CSV.
:param class_weights: Class weights to use for the dataset. If None, will compute them from the dataset.
"""
batch_size_inf = batch_size_inf or batch_size
@@ -97,7 +99,7 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
self.test_dataset: _SlidesOrTilesDataset
self.dataframe_kwargs = dataframe_kwargs or {}
self.train_dataset, self.val_dataset, self.test_dataset = self.get_splits()
self.class_weights = self.train_dataset.get_class_weights()
self.class_weights = class_weights if class_weights is not None else self.train_dataset.get_class_weights()
self.seed = seed
self.dataloader_kwargs = dataloader_kwargs or {}

Просмотреть файл

@@ -182,3 +182,21 @@ def test_assertion_error_missing_seed(mock_panda_slides_root_dir: Path) -> None:
loading_params=get_loading_params(),
)
slides_datamodule._get_ddp_sampler(MagicMock(), ModelKey.TRAIN)
def test_histo_module_class_weights(mock_panda_slides_root_dir: Path) -> None:
    """Check that an explicit `class_weights` argument overrides the data module's default weights.

    Two modules are built on the same mock dataset: one without the argument, one with an
    empty tensor passed as `class_weights`. Only the shapes of the resulting
    `class_weights` attributes are compared.

    :param mock_panda_slides_root_dir: Root directory of the mock PANDA slides dataset (fixture).
    """
    # Case 1: no weights supplied, so the module falls back to its own weights, which must not be empty.
    default_module = PandaSlidesDataModule(
        root_path=mock_panda_slides_root_dir,
        tiling_params=TilingParams(),
        loading_params=get_loading_params(),
    )
    default_shape = default_module.class_weights.shape
    assert default_shape != (0,)

    # Case 2: an empty tensor has shape (0,), which case 1 just ruled out for the default
    # weights, so a matching shape shows that the supplied tensor was stored on the module.
    supplied_weights = torch.tensor([])
    custom_module = PandaSlidesDataModule(
        root_path=mock_panda_slides_root_dir,
        tiling_params=TilingParams(),
        loading_params=get_loading_params(),
        class_weights=supplied_weights,
    )
    stored_shape = custom_module.class_weights.shape
    assert stored_shape == supplied_weights.shape