ENH: Clean up design of slides dataset (#875)

Metadata columns are presently set via class variables, rather than in
the constructor. Hence, we can't have several datasets that use
different metadata columns without them interfering with each other.
This caused odd test failures, because test results depended on the
order in which the tests were executed.
This commit is contained in:
Anton Schwaighofer 2023-04-20 21:51:40 +01:00 коммит произвёл GitHub
Родитель da5547c726
Коммит abf36f5d2e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
13 изменённых файлов: 112 добавлений и 83 удалений

Просмотреть файл

@ -153,7 +153,7 @@ class DeepSMILESlidesPanda(BaseMILSlides, BaseDeepSMILEPanda):
tiling_params=create_from_matching_params(self, TilingParams),
loading_params=create_from_matching_params(self, LoadingParams),
seed=self.get_effective_random_seed(),
transforms_dict=self.get_transforms_dict(PandaDataset.IMAGE_COLUMN),
transforms_dict=self.get_transforms_dict(SlideKey.IMAGE),
crossval_count=self.crossval_count,
crossval_index=self.crossval_index,
dataloader_kwargs=self.get_dataloader_kwargs(),

Просмотреть файл

@ -15,7 +15,6 @@ from health_cpath.utils.wsi_utils import TilingParams
from health_ml.networks.layers.attention_layers import TransformerPooling, TransformerPoolingBenchmark
from health_ml.utils.checkpoint_utils import CheckpointParser
from health_ml.deep_learning_config import OptimizerParams
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.datamodules.panda_module_benchmark import PandaSlidesDataModuleBenchmark
from health_cpath.models.encoders import (
HistoSSLEncoder,
@ -116,7 +115,7 @@ class DeepSMILESlidesPandaBenchmark(DeepSMILESlidesPanda):
tiling_params=create_from_matching_params(self, TilingParams),
loading_params=create_from_matching_params(self, LoadingParams),
seed=self.get_effective_random_seed(),
transforms_dict=self.get_transforms_dict(PandaDataset.IMAGE_COLUMN),
transforms_dict=self.get_transforms_dict(SlideKey.IMAGE),
crossval_count=self.crossval_count,
crossval_index=self.crossval_index,
dataloader_kwargs=self.get_dataloader_kwargs(),

Просмотреть файл

@ -50,7 +50,7 @@ class PandaSlidesDataModule(SlidesDataModule):
proportion_train=0.8,
proportion_test=0.1,
proportion_val=0.1,
subject_column=dataset.SLIDE_ID_COLUMN,
subject_column=dataset.slide_id_column,
)
if self.crossval_count > 1:

Просмотреть файл

@ -28,7 +28,7 @@ class PandaSlidesDataModuleBenchmark(SlidesDataModule):
proportion_train=0.8,
proportion_test=0.0,
proportion_val=0.2,
subject_column=dataset.SLIDE_ID_COLUMN,
subject_column=dataset.slide_id_column,
)
if self.crossval_count > 1:

Просмотреть файл

@ -150,26 +150,15 @@ class TilesDataset(Dataset):
self.dataset_df = self.dataset_df.assign(**{TilesDataset.TILE_Y_COLUMN: self.dataset_df[TileKey.TILE_TOP]})
DEFAULT_DATASET_CSV = "dataset.csv"
class SlidesDataset(Dataset):
"""Base class for datasets of WSIs, iterating dictionaries of image paths and metadata.
The output dictionaries are indexed by `..utils.naming.SlideKey`.
:param SLIDE_ID_COLUMN: CSV column name for slide ID.
:param IMAGE_COLUMN: CSV column name for relative path to image file.
:param SPLIT_COLUMN: CSV column name for train/test split (optional).
:param DEFAULT_CSV_FILENAME: Default name of the dataset CSV at the dataset rood directory.
"""
SLIDE_ID_COLUMN: str = 'slide_id'
IMAGE_COLUMN: str = 'image'
MASK_COLUMN: Optional[str] = None
SPLIT_COLUMN: Optional[str] = None
METADATA_COLUMNS: Tuple[str, ...] = ()
DEFAULT_CSV_FILENAME: str = "dataset.csv"
def __init__(
self,
root: Union[str, Path],
@ -180,44 +169,60 @@ class SlidesDataset(Dataset):
label_column: str = DEFAULT_LABEL_COLUMN,
n_classes: int = 1,
dataframe_kwargs: Dict[str, Any] = {},
default_csv_filename: str = DEFAULT_DATASET_CSV,
slide_id_column: str = "slide_id",
image_column: str = "image",
mask_column: Optional[str] = None,
split_column: Optional[str] = None,
metadata_columns: Tuple[str, ...] = (),
) -> None:
"""
:param root: Root directory of the dataset.
:param dataset_csv: Full path to a dataset CSV file, containing at least
`TILE_ID_COLUMN`, `SLIDE_ID_COLUMN`, and `IMAGE_COLUMN`. If omitted, the CSV will be read
from `"{root}/{DEFAULT_CSV_FILENAME}"`.
`slide_id_column` and `image_column`. If omitted, the CSV will be read
from `"{root}/{default_csv_filename}"`.
:param dataset_df: A potentially pre-processed dataframe in the same format as would be read
from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
:param train: If `True`, loads only the training split (resp. `False` for test split). By
default (`None`), loads the entire dataset as-is.
default (`None`), loads the entire dataset as-is.
:param validate_columns: Whether to call `validate_columns()` at the end of `__init__()`.
`validate_columns()` checks that the loaded data frame for the dataset contains the expected column names
for this class
`validate_columns()` checks that the loaded data frame for the dataset contains the expected column names
for this class
:param label_column: CSV column name for tile label. Default is `DEFAULT_LABEL_COLUMN="label"`.
:param n_classes: Number of classes indexed in `label_column`. Default is 1 for binary classification.
:param dataframe_kwargs: Keyword arguments to pass to `pd.read_csv()` when loading the dataset CSV.
:param slide_id_column: CSV column name for slide ID. Default is `slide_id`.
:param image_column: CSV column name for relative path to image file. Default is `image`.
:param mask_column: CSV column name for relative path to mask file. Default is `None`.
:param split_column: CSV column name for train/test split. Default is `None`.
:param default_csv_filename: Default name of the dataset CSV at the dataset root directory.
"""
if self.SPLIT_COLUMN is None and train is not None:
raise ValueError("Train/test split was specified but dataset has no split column")
self.root_dir = Path(root)
self.label_column = label_column
self.n_classes = n_classes
self.dataframe_kwargs = dataframe_kwargs
self.slide_id_column = slide_id_column
self.image_column = image_column
self.mask_column = mask_column
self.split_column = split_column
self.metadata_columns = metadata_columns
self.default_csv_filename = default_csv_filename
if self.split_column is None and train is not None:
raise ValueError("Train/test split was specified but dataset has no split column")
if dataset_df is not None:
self.dataset_csv = None
else:
self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
self.dataset_csv = dataset_csv or self.root_dir / self.default_csv_filename
dataset_df = pd.read_csv(self.dataset_csv, **self.dataframe_kwargs)
if dataset_df.index.name != self.SLIDE_ID_COLUMN:
dataset_df = dataset_df.set_index(self.SLIDE_ID_COLUMN)
if dataset_df.index.name != self.slide_id_column:
dataset_df = dataset_df.set_index(self.slide_id_column)
if train is None:
self.dataset_df = dataset_df
else:
split = DEFAULT_TRAIN_SPLIT_LABEL if train else DEFAULT_TEST_SPLIT_LABEL
self.dataset_df = dataset_df[dataset_df[self.SPLIT_COLUMN] == split]
self.dataset_df = dataset_df[dataset_df[self.split_column] == split]
if validate_columns:
self.validate_columns()
@ -228,14 +233,14 @@ class SlidesDataset(Dataset):
If the constructor is overloaded in a subclass, you can pass `validate_columns=False` and
call `validate_columns()` after creating derived columns, for example.
"""
mandatory_columns = {self.IMAGE_COLUMN, self.label_column, self.MASK_COLUMN, self.SPLIT_COLUMN}
mandatory_columns = {self.image_column, self.label_column, self.mask_column, self.split_column}
optional_columns = (
set(self.dataframe_kwargs["usecols"]) if "usecols" in self.dataframe_kwargs else set(self.METADATA_COLUMNS)
set(self.dataframe_kwargs["usecols"]) if "usecols" in self.dataframe_kwargs else set(self.metadata_columns)
)
columns = mandatory_columns.union(optional_columns)
# SLIDE_ID_COLUMN is used for indexing and is not in df.columns anymore
# None might be in columns if SPLITS_COLUMN is None
columns_not_found = columns - set(self.dataset_df.columns) - {None, self.SLIDE_ID_COLUMN}
# slide_id_column is used for indexing and is not in df.columns anymore
# None might be in columns if split_column is None
columns_not_found = columns - set(self.dataset_df.columns) - {None, self.slide_id_column}
if len(columns_not_found) > 0:
raise ValueError(f"Expected columns '{columns_not_found}' not found in the dataframe")
@ -247,18 +252,18 @@ class SlidesDataset(Dataset):
slide_row = self.dataset_df.loc[slide_id]
sample = {SlideKey.SLIDE_ID: slide_id}
rel_image_path = slide_row[self.IMAGE_COLUMN]
rel_image_path = slide_row[self.image_column]
sample[SlideKey.IMAGE] = str(self.root_dir / rel_image_path)
# we're replicating this column because we want to propagate the path to the batch
sample[SlideKey.IMAGE_PATH] = sample[SlideKey.IMAGE]
if self.MASK_COLUMN:
rel_mask_path = slide_row[self.MASK_COLUMN]
if self.mask_column:
rel_mask_path = slide_row[self.mask_column]
sample[SlideKey.MASK] = str(self.root_dir / rel_mask_path)
sample[SlideKey.MASK_PATH] = sample[SlideKey.MASK]
sample[SlideKey.LABEL] = slide_row[self.label_column]
sample[SlideKey.METADATA] = {col: slide_row[col] for col in self.METADATA_COLUMNS}
sample[SlideKey.METADATA] = {col: slide_row[col] for col in self.metadata_columns}
return sample
@classmethod

Просмотреть файл

@ -8,6 +8,16 @@ from typing import Any, Dict, Union, Optional
from health_cpath.datasets.base_dataset import SlidesDataset
class PandaColumns:
    """Column names used by `PandaDataset` for the PANDA dataset dataframe."""

    # Column holding the slide ID; passed to the base class as `slide_id_column`.
    SLIDE_ID = 'image_id'
    # Column holding the relative path to the image file. `PandaDataset` fills this
    # column in itself, because the PANDA CSV does not come with image paths.
    IMAGE = 'image'
    # Column holding the relative path to the mask file; also filled in by `PandaDataset`.
    MASK = 'mask'
    # Columns passed to the base class as `metadata_columns`.
    METADATA = ('data_provider', 'isup_grade', 'gleason_score')
PANDA_CSV_FILENAME = "train.csv"
class PandaDataset(SlidesDataset):
"""Dataset class for loading files from the PANDA challenge dataset.
@ -17,14 +27,6 @@ class PandaDataset(SlidesDataset):
Ref.: https://www.kaggle.com/c/prostate-cancer-grade-assessment/overview
"""
SLIDE_ID_COLUMN = 'image_id'
IMAGE_COLUMN = 'image'
MASK_COLUMN = 'mask'
METADATA_COLUMNS = ('data_provider', 'isup_grade', 'gleason_score')
DEFAULT_CSV_FILENAME = "train.csv"
def __init__(
self,
root: Union[str, Path],
@ -42,9 +44,14 @@ class PandaDataset(SlidesDataset):
label_column=label_column,
n_classes=n_classes,
dataframe_kwargs=dataframe_kwargs,
slide_id_column=PandaColumns.SLIDE_ID,
image_column=PandaColumns.IMAGE,
mask_column=PandaColumns.MASK,
metadata_columns=PandaColumns.METADATA,
default_csv_filename=PANDA_CSV_FILENAME,
)
# PANDA CSV does not come with paths for image and mask files
slide_ids = self.dataset_df.index
self.dataset_df[self.IMAGE_COLUMN] = "train_images/" + slide_ids + ".tiff"
self.dataset_df[self.MASK_COLUMN] = "train_label_masks/" + slide_ids + "_mask.tiff"
self.dataset_df[self.image_column] = "train_images/" + slide_ids + ".tiff"
self.dataset_df[self.mask_column] = "train_label_masks/" + slide_ids + "_mask.tiff"
self.validate_columns()

Просмотреть файл

@ -11,6 +11,17 @@ import pandas as pd
from health_cpath.datasets.base_dataset import DEFAULT_LABEL_COLUMN, SlidesDataset
TCGA_PRAD_DATASET_FILE = "dataset.csv"
class TcgaColumns:
    """Column names for TCGA dataset CSV files."""

    # Column holding the relative path to the image file; passed to the base class
    # as `image_column`.
    IMAGE = "image_path"
    # The two label columns that `TcgaPradDataset` combines (bitwise OR, cast to int)
    # into its `label_column`.
    LABEL1 = "label1"
    LABEL2 = "label2"
class TcgaPradDataset(SlidesDataset):
"""Dataset class for loading TCGA-PRAD slides.
@ -21,10 +32,6 @@ class TcgaPradDataset(SlidesDataset):
- `'label'` (int, 0 or 1): label for predicting positive or negative
"""
IMAGE_COLUMN: str = 'image_path'
DEFAULT_CSV_FILENAME: str = "dataset.csv"
def __init__(
self,
root: Union[str, Path],
@ -39,9 +46,19 @@ class TcgaPradDataset(SlidesDataset):
:param dataset_df: A potentially pre-processed dataframe in the same format as would be read
from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
"""
super().__init__(root, dataset_csv, dataset_df, validate_columns=False, label_column=label_column)
super().__init__(
root,
dataset_csv,
dataset_df,
validate_columns=False,
label_column=label_column,
default_csv_filename=TCGA_PRAD_DATASET_FILE,
image_column=TcgaColumns.IMAGE,
)
# Example of how to define a custom label column from existing columns:
self.dataset_df[self.label_column] = (self.dataset_df['label1'] | self.dataset_df['label2']).astype(
self.dataset_df[self.label_column] = (
self.dataset_df[TcgaColumns.LABEL1] | self.dataset_df[TcgaColumns.LABEL2]
).astype(
int
) # noqa: W503
self.validate_columns()

Просмотреть файл

@ -8,11 +8,11 @@ import pandas as pd
from health_cpath.datasets.default_paths import TCGA_CRCK_DATASET_ID
from health_cpath.utils.tcga_utils import extract_fields
from health_cpath.datasets.tcga_prad_dataset import TcgaPradDataset
from health_cpath.datasets.tcga_prad_dataset import TCGA_PRAD_DATASET_FILE
def check_dataset_csv_paths(dataset_dir: Path) -> None:
df = pd.read_csv(dataset_dir / TcgaPradDataset.DEFAULT_CSV_FILENAME)
df = pd.read_csv(dataset_dir / TCGA_PRAD_DATASET_FILE)
for img_path in df.image:
assert (dataset_dir / img_path).is_file()
@ -38,6 +38,6 @@ if __name__ == '__main__':
# takes up to ~20 seconds
df = df.apply(extract_fields, axis='columns', result_type='expand')
df.to_csv(dataset_dir / TcgaPradDataset.DEFAULT_CSV_FILENAME, index=False)
df.to_csv(dataset_dir / TCGA_PRAD_DATASET_FILE, index=False)
check_dataset_csv_paths(dataset_dir)

Просмотреть файл

@ -24,7 +24,7 @@ from health_azure.argparsing import apply_overrides, parse_arguments
from health_cpath.preprocessing.loading import WSIBackend
from health_cpath.utils.montage_config import MontageConfig, create_montage_argparser
from health_cpath.utils.naming import SlideKey
from health_cpath.datasets.base_dataset import SlidesDataset
from health_cpath.datasets.base_dataset import DEFAULT_DATASET_CSV, SlidesDataset
from health_ml.utils.type_annotations import TupleInt3
@ -425,13 +425,13 @@ class MontageCreation(MontageConfig):
dataset = SlidesDataset(root=input_folder)
except Exception as ex:
logging.error("Unable to load dataset.")
file = input_folder / SlidesDataset.DEFAULT_CSV_FILENAME
file = input_folder / DEFAULT_DATASET_CSV
# Print the whole directory tree to check where the problem is.
while str(file) != str(file.root):
logging.debug(f"File: {file}, exists: {file.exists()}")
file = file.parent
raise ValueError(
f"Unable to load dataset. Check if the file {SlidesDataset.DEFAULT_CSV_FILENAME} "
f"Unable to load dataset. Check if the file {DEFAULT_DATASET_CSV} "
f"exists, or provide a file name pattern via --image_glob_pattern. Error: {ex}"
)
return dataset
@ -476,7 +476,7 @@ class MontageCreation(MontageConfig):
:param exclude_items: If True, exclude the list in `items` from the montage. If False, include
only those in the montage.
:param restrict_by_column: The column name that should be used for inclusion/exclusion lists
(default=dataset.SLIDE_ID_COLUMN).
(default=dataset.slide_id_column).
:return: A path to the created montage, or None if no images were available for creating the montage.
"""
if isinstance(dataset, pd.DataFrame):
@ -489,7 +489,7 @@ class MontageCreation(MontageConfig):
if isinstance(dataset, pd.DataFrame):
restrict_by_column = SlideKey.SLIDE_ID.value
else:
restrict_by_column = dataset.SLIDE_ID_COLUMN
restrict_by_column = dataset.slide_id_column
if items:
if exclude_items:
logging.info(f"Using dataset column '{restrict_by_column}' to exclude slides")

Просмотреть файл

@ -95,12 +95,15 @@ class TiffConversionConfig(param.Parameterized):
:param output_folder: The folder where the new dataset csv file will be saved.
"""
new_dataset_df = deepcopy(self.slides_dataset.dataset_df)
new_dataset_df[self.slides_dataset.IMAGE_COLUMN] = (
new_dataset_df[self.slides_dataset.IMAGE_COLUMN]
new_dataset_df[self.slides_dataset.image_column] = (
new_dataset_df[self.slides_dataset.image_column]
.str.replace(AMPERSAND, self.replace_ampersand_by)
.map(lambda x: str(Path(x).with_suffix(TIFF_EXTENSION)))
)
new_dataset_path = output_folder / (self.converted_dataset_csv or self.slides_dataset.DEFAULT_CSV_FILENAME)
new_dataset_file = (
self.converted_dataset_csv if self.converted_dataset_csv else self.slides_dataset.default_csv_filename
)
new_dataset_path = output_folder / new_dataset_file
new_dataset_df.to_csv(new_dataset_path, sep="\t" if new_dataset_path.suffix == ".tsv" else ",")
logging.info(f"Saved new dataset tsv file to {new_dataset_path}")

Просмотреть файл

@ -14,7 +14,7 @@ import torch
from pytorch_lightning import seed_everything
from health_cpath.configs.classification.DeepSMILESlidesPandaBenchmark import DeepSMILESlidesPandaBenchmark
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.datasets.panda_dataset import PandaColumns, PandaDataset
from health_cpath.preprocessing.loading import ROIType, WSIBackend
from health_cpath.utils.naming import SlideKey
from testhisto.mocks.base_data_generator import MockHistoDataType
@ -112,7 +112,7 @@ def test_validate_columns(tmp_path: Path) -> None:
background_val=255,
tiles_pos_type=TilesPositioningType.RANDOM,
)
usecols = [PandaDataset.SLIDE_ID_COLUMN, PandaDataset.MASK_COLUMN]
usecols = [PandaColumns.SLIDE_ID, PandaColumns.MASK]
with pytest.raises(ValueError, match=r"Expected columns"):
_ = PandaDataset(root=tmp_path, dataframe_kwargs={"usecols": usecols})
_ = PandaDataset(root=tmp_path, dataframe_kwargs={"usecols": usecols + [PandaDataset.METADATA_COLUMNS[1]]})
_ = PandaDataset(root=tmp_path, dataframe_kwargs={"usecols": usecols + [PandaColumns.METADATA[1]]})

Просмотреть файл

@ -11,7 +11,7 @@ import pandas as pd
import torch
from tifffile.tifffile import TiffWriter, PHOTOMETRIC, COMPRESSION
from torch import Tensor
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.datasets.panda_dataset import PandaColumns, PANDA_CSV_FILENAME
from health_cpath.preprocessing.tiff_conversion import ResolutionUnit
from testhisto.mocks.base_data_generator import MockHistoDataGenerator, MockHistoDataType, PANDA_N_CLASSES
@ -87,17 +87,15 @@ class MockPandaSlidesGenerator(MockHistoDataGenerator):
list(self.ISUP_GRADE_MAPPING.keys()),
self.n_slides // PANDA_N_CLASSES + 1,
)
mock_metadata: dict = {
col: [] for col in [PandaDataset.SLIDE_ID_COLUMN, PandaDataset.MASK_COLUMN, *PandaDataset.METADATA_COLUMNS]
}
mock_metadata: dict = {col: [] for col in [PandaColumns.SLIDE_ID, PandaColumns.MASK, *PandaColumns.METADATA]}
for slide_id in range(self.n_slides):
mock_metadata[PandaDataset.SLIDE_ID_COLUMN].append(f"_{slide_id}")
mock_metadata[PandaDataset.MASK_COLUMN].append(f"_{slide_id}_mask")
mock_metadata[PandaColumns.SLIDE_ID].append(f"_{slide_id}")
mock_metadata[PandaColumns.MASK].append(f"_{slide_id}_mask")
mock_metadata[self.DATA_PROVIDER].append(np.random.choice(self.DATA_PROVIDERS_VALUES))
mock_metadata[self.ISUP_GRADE].append(isup_grades[slide_id])
mock_metadata[self.GLEASON_SCORE].append(np.random.choice(self.ISUP_GRADE_MAPPING[isup_grades[slide_id]]))
df = pd.DataFrame(data=mock_metadata)
csv_filename = self.dest_data_path / PandaDataset.DEFAULT_CSV_FILENAME
csv_filename = self.dest_data_path / PANDA_CSV_FILENAME
df.to_csv(csv_filename, index=False)
def create_mock_wsi(self, tiles: Tensor) -> Tuple[np.ndarray, Optional[np.ndarray]]:

Просмотреть файл

@ -17,8 +17,8 @@ from health_cpath.utils.montage import (
make_montage_from_dir,
restrict_dataset,
)
from health_cpath.datasets.base_dataset import SlidesDataset
from health_cpath.datasets.panda_dataset import PandaDataset
from health_cpath.datasets.base_dataset import DEFAULT_DATASET_CSV, SlidesDataset
from health_cpath.datasets.panda_dataset import PandaColumns, PandaDataset
from health_cpath.scripts.create_montage import main as script_main
from health_cpath.utils.naming import SlideKey
from testhisto.mocks.base_data_generator import MockHistoDataType
@ -62,8 +62,8 @@ def temp_panda_dataset(tmp_path_factory: pytest.TempPathFactory) -> Generator:
"""A fixture that creates a PandaDataset object with randomly created slides."""
tmp_path = tmp_path_factory.mktemp("mock_panda")
_create_slides_images(tmp_path)
usecols = [PandaDataset.SLIDE_ID_COLUMN, PandaDataset.MASK_COLUMN]
yield PandaDataset(root=tmp_path, dataframe_kwargs={"usecols": usecols + list(PandaDataset.METADATA_COLUMNS)})
usecols = [PandaColumns.SLIDE_ID, PandaColumns.MASK]
yield PandaDataset(root=tmp_path, dataframe_kwargs={"usecols": usecols + list(PandaColumns.METADATA)})
@pytest.fixture(scope="module")
@ -86,7 +86,7 @@ def temp_slides_dataset(tmp_path_factory: pytest.TempPathFactory) -> Generator:
SlideKey.LABEL: [f"Label {i}" for i in range(NUM_SLIDES)],
}
df = pd.DataFrame(data=metadata)
csv_filename = tmp_path / SlidesDataset.DEFAULT_CSV_FILENAME
csv_filename = tmp_path / DEFAULT_DATASET_CSV
df.to_csv(csv_filename, index=False)
# Tests fail non-deterministically, saying that the dataset file does not exist (yet). Hence, wait.
wait_until_file_exists(csv_filename)