зеркало из https://github.com/microsoft/hi-ml.git
ENH: Add a class_weights argument to the data module (#871)
This commit is contained in:
Родитель
8dbd3249cc
Коммит
5746042f68
|
@ -25,7 +25,7 @@ from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
|
|||
from monai.transforms import Compose
|
||||
|
||||
|
||||
_SlidesOrTilesDataset = TypeVar('_SlidesOrTilesDataset', SlidesDataset, TilesDataset)
|
||||
_SlidesOrTilesDataset = TypeVar("_SlidesOrTilesDataset", SlidesDataset, TilesDataset)
|
||||
|
||||
|
||||
class CacheMode(Enum):
|
||||
|
@ -57,6 +57,7 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
|
|||
pl_replace_sampler_ddp: bool = True,
|
||||
dataloader_kwargs: Optional[Dict[str, Any]] = None,
|
||||
dataframe_kwargs: Optional[Dict[str, Any]] = None,
|
||||
class_weights: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
"""
|
||||
:param root_path: Root directory of the source dataset.
|
||||
|
@ -80,6 +81,7 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
|
|||
:param pl_replace_sampler_ddp: If True, replace the sampler with a DistributedSampler when using DDP.
|
||||
:param dataloader_kwargs: Additional keyword arguments for the training, validation, and test dataloaders.
|
||||
:param dataframe_kwargs: Keyword arguments to pass to `pd.read_csv()` when loading the dataset CSV.
|
||||
:param class_weights: Class weights to use for the dataset. If None, will compute them from the dataset.
|
||||
"""
|
||||
|
||||
batch_size_inf = batch_size_inf or batch_size
|
||||
|
@ -97,7 +99,7 @@ class HistoDataModule(LightningDataModule, Generic[_SlidesOrTilesDataset]):
|
|||
self.test_dataset: _SlidesOrTilesDataset
|
||||
self.dataframe_kwargs = dataframe_kwargs or {}
|
||||
self.train_dataset, self.val_dataset, self.test_dataset = self.get_splits()
|
||||
self.class_weights = self.train_dataset.get_class_weights()
|
||||
self.class_weights = class_weights if class_weights is not None else self.train_dataset.get_class_weights()
|
||||
self.seed = seed
|
||||
self.dataloader_kwargs = dataloader_kwargs or {}
|
||||
|
||||
|
|
|
@ -182,3 +182,21 @@ def test_assertion_error_missing_seed(mock_panda_slides_root_dir: Path) -> None:
|
|||
loading_params=get_loading_params(),
|
||||
)
|
||||
slides_datamodule._get_ddp_sampler(MagicMock(), ModelKey.TRAIN)
|
||||
|
||||
|
||||
def test_histo_module_class_weights(mock_panda_slides_root_dir: Path) -> None:
|
||||
"""Test if the class weights argument of the HistoDataModule is correctly set."""
|
||||
datamodule = PandaSlidesDataModule(
|
||||
root_path=mock_panda_slides_root_dir,
|
||||
tiling_params=TilingParams(),
|
||||
loading_params=get_loading_params(),
|
||||
)
|
||||
assert datamodule.class_weights.shape != (0,)
|
||||
new_class_weights = torch.tensor([])
|
||||
datamodule = PandaSlidesDataModule(
|
||||
root_path=mock_panda_slides_root_dir,
|
||||
tiling_params=TilingParams(),
|
||||
loading_params=get_loading_params(),
|
||||
class_weights=new_class_weights,
|
||||
)
|
||||
assert datamodule.class_weights.shape == new_class_weights.shape
|
||||
|
|
Загрузка…
Ссылка в новой задаче