DOC: Add all `InnerEye/ML` docstrings to ReadTheDocs (#783)
* 📝 Create basic for ML API * 📝 Add ML/configs base doc files * 📝 Finish ML/configs API * 📝 Update augmentations * 📝 Add ML/dataset API docs * 📝 Add rst skeleton for ML/models * 📝 Fix docstring missing newlines * Remove script * 📝 Finish ML/models API docs * 📝 Start ML/SSL API. Fix some formatting issues * 📝 Correct whitespace issues in `:param` * 📝 Fix whitespace errors on `:return` statements * 📝 Fix :return: statements * 📝 Finish ML/SSL API * 📝 Add ML/utils API docs * 📝 Add visualizer docs, fix `:raise` indents * 📝 Fix more issues with the `:raises:` formatting * ♻️ Restructuring folders * 📝 Limit API `toctree` depth * 📝 Add primary InnerEye/ML files API to docs * 📝 Fix and add `InnerEye/ML/*.py` docs * ⚰️ Remove weird `settings.json` change * ♻️ 💡 Address review comments
This commit is contained in:
Родитель
c1b363e158
Коммит
59214c268e
|
@ -12,5 +12,5 @@
|
|||
},
|
||||
"files.trimTrailingWhitespace": true,
|
||||
"files.trimFinalNewlines": true,
|
||||
"files.insertFinalNewline": true
|
||||
"files.insertFinalNewline": true,
|
||||
}
|
||||
|
|
|
@ -168,6 +168,7 @@ class AzureConfig(GenericConfig):
|
|||
"""
|
||||
Creates an AzureConfig object with default values, with the keys/secrets populated from values in the
|
||||
given YAML file. If a `project_root` folder is provided, a private settings file is read from there as well.
|
||||
|
||||
:param yaml_file_path: Path to the YAML file that contains values to create the AzureConfig
|
||||
:param project_root: A folder in which to search for a private settings file.
|
||||
:return: AzureConfig with values populated from the yaml files.
|
||||
|
@ -231,6 +232,7 @@ class AzureConfig(GenericConfig):
|
|||
def fetch_run(self, run_recovery_id: str) -> Run:
|
||||
"""
|
||||
Gets an instantiated Run object for a given run recovery ID (format experiment_name:run_id).
|
||||
|
||||
:param run_recovery_id: A run recovery ID (format experiment_name:run_id)
|
||||
"""
|
||||
return fetch_run(workspace=self.get_workspace(), run_recovery_id=run_recovery_id)
|
||||
|
|
|
@ -39,6 +39,7 @@ def get_git_tags(azure_config: AzureConfig) -> Dict[str, str]:
|
|||
Creates a dictionary with git-related information, like branch and commit ID. The dictionary key is a string
|
||||
that can be used as a tag on an AzureML run, the dictionary value is the git information. If git information
|
||||
is passed in via commandline arguments, those take precedence over information read out from the repository.
|
||||
|
||||
:param azure_config: An AzureConfig object specifying git-related commandline args.
|
||||
:return: A dictionary mapping from tag name to git info.
|
||||
"""
|
||||
|
@ -56,6 +57,7 @@ def get_git_tags(azure_config: AzureConfig) -> Dict[str, str]:
|
|||
def additional_run_tags(azure_config: AzureConfig, commandline_args: str) -> Dict[str, str]:
|
||||
"""
|
||||
Gets the set of tags that will be added to the AzureML run as metadata, like git status and user name.
|
||||
|
||||
:param azure_config: The configurations for the present AzureML job
|
||||
:param commandline_args: A string that holds all commandline arguments that were used for the present run.
|
||||
"""
|
||||
|
@ -77,6 +79,7 @@ def additional_run_tags(azure_config: AzureConfig, commandline_args: str) -> Dic
|
|||
def create_experiment_name(azure_config: AzureConfig) -> str:
|
||||
"""
|
||||
Gets the name of the AzureML experiment. This is taken from the commandline, or from the git branch.
|
||||
|
||||
:param azure_config: The object containing all Azure-related settings.
|
||||
:return: The name to use for the AzureML experiment.
|
||||
"""
|
||||
|
@ -104,7 +107,7 @@ def create_dataset_configs(azure_config: AzureConfig,
|
|||
:param all_dataset_mountpoints: When using the datasets in AzureML, these are the per-dataset mount points.
|
||||
:param all_local_datasets: The paths for all local versions of the datasets.
|
||||
:return: A list of DatasetConfig objects, in the same order as datasets were provided in all_azure_dataset_ids,
|
||||
omitting datasets with an empty name.
|
||||
omitting datasets with an empty name.
|
||||
"""
|
||||
datasets: List[DatasetConfig] = []
|
||||
num_local = len(all_local_datasets)
|
||||
|
@ -147,6 +150,7 @@ def create_runner_parser(model_config_class: type = None) -> argparse.ArgumentPa
|
|||
"""
|
||||
Creates a commandline parser, that understands all necessary arguments for running a script in Azure,
|
||||
plus all arguments for the given class. The class must be a subclass of GenericConfig.
|
||||
|
||||
:param model_config_class: A class that contains the model-specific parameters.
|
||||
:return: An instance of ArgumentParser.
|
||||
"""
|
||||
|
@ -167,11 +171,12 @@ def parse_args_and_add_yaml_variables(parser: ArgumentParser,
|
|||
"""
|
||||
Reads arguments from sys.argv, modifies them with secrets from local YAML files,
|
||||
and parses them using the given argument parser.
|
||||
|
||||
:param project_root: The root folder for the whole project. Only used to access a private settings file.
|
||||
:param parser: The parser to use.
|
||||
:param yaml_config_file: The path to the YAML file that contains values to supply into sys.argv.
|
||||
:param fail_on_unknown_args: If True, raise an exception if the parser encounters an argument that it does not
|
||||
recognize. If False, unrecognized arguments will be ignored, and added to the "unknown" field of the parser result.
|
||||
recognize. If False, unrecognized arguments will be ignored, and added to the "unknown" field of the parser result.
|
||||
:return: The parsed arguments, and overrides
|
||||
"""
|
||||
settings_from_yaml = read_all_settings(yaml_config_file, project_root=project_root)
|
||||
|
@ -183,6 +188,7 @@ def parse_args_and_add_yaml_variables(parser: ArgumentParser,
|
|||
def _create_default_namespace(parser: ArgumentParser) -> Namespace:
|
||||
"""
|
||||
Creates an argparse Namespace with all parser-specific default values set.
|
||||
|
||||
:param parser: The parser to work with.
|
||||
:return:
|
||||
"""
|
||||
|
@ -207,10 +213,12 @@ def parse_arguments(parser: ArgumentParser,
|
|||
Parses a list of commandline arguments with a given parser, and adds additional information read
|
||||
from YAML files. Returns results broken down into a full arguments dictionary, a dictionary of arguments
|
||||
that were set to non-default values, and unknown arguments.
|
||||
|
||||
:param parser: The parser to use
|
||||
:param settings_from_yaml: A dictionary of settings read from a YAML config file.
|
||||
:param fail_on_unknown_args: If True, raise an exception if the parser encounters an argument that it does not
|
||||
recognize. If False, unrecognized arguments will be ignored, and added to the "unknown" field of the parser result.
|
||||
recognize. If False, unrecognized arguments will be ignored, and added to the "unknown" field of the parser result.
|
||||
|
||||
:param args: Arguments to parse. If not given, use those in sys.argv
|
||||
:return: The parsed arguments, and overrides
|
||||
"""
|
||||
|
@ -261,6 +269,7 @@ def run_duration_string_to_seconds(s: str) -> Optional[int]:
|
|||
Parse a string that represents a timespan, and returns it converted into seconds. The string is expected to be
|
||||
floating point number with a single character suffix s, m, h, d for seconds, minutes, hours, day.
|
||||
Examples: '3.5h', '2d'. If the argument is an empty string, None is returned.
|
||||
|
||||
:param s: The string to parse.
|
||||
:return: The timespan represented in the string converted to seconds.
|
||||
"""
|
||||
|
|
|
@ -47,6 +47,7 @@ def split_recovery_id(id: str) -> Tuple[str, str]:
|
|||
The argument can be in the format 'experiment_name:run_id',
|
||||
or just a run ID like user_branch_abcde12_123. In the latter case, everything before the last
|
||||
two alphanumeric parts is assumed to be the experiment name.
|
||||
|
||||
:param id:
|
||||
:return: experiment name and run name
|
||||
"""
|
||||
|
@ -74,9 +75,10 @@ def fetch_run(workspace: Workspace, run_recovery_id: str) -> Run:
|
|||
Finds an existing run in an experiment, based on a recovery ID that contains the experiment ID
|
||||
and the actual RunId. The run can be specified either in the experiment_name:run_id format,
|
||||
or just the run_id.
|
||||
|
||||
:param workspace: the configured AzureML workspace to search for the experiment.
|
||||
:param run_recovery_id: The Run to find. Either in the full recovery ID format, experiment_name:run_id
|
||||
or just the run_id
|
||||
or just the run_id
|
||||
:return: The AzureML run.
|
||||
"""
|
||||
return get_aml_run_from_run_id(aml_workspace=workspace, run_id=run_recovery_id)
|
||||
|
@ -85,6 +87,7 @@ def fetch_run(workspace: Workspace, run_recovery_id: str) -> Run:
|
|||
def fetch_runs(experiment: Experiment, filters: List[str]) -> List[Run]:
|
||||
"""
|
||||
Fetch the runs in an experiment.
|
||||
|
||||
:param experiment: the experiment to fetch runs from
|
||||
:param filters: a list of run status to include. Must be subset of [Running, Completed, Failed, Canceled].
|
||||
:return: the list of runs in the experiment
|
||||
|
@ -107,10 +110,11 @@ def fetch_child_runs(
|
|||
"""
|
||||
Fetch child runs for the provided runs that have the provided AML status (or fetch all by default)
|
||||
and have a run_recovery_id tag value set (this is to ignore superfluous AML infrastructure platform runs).
|
||||
|
||||
:param run: parent run to fetch child run from
|
||||
:param status: if provided, returns only child runs with this status
|
||||
:param expected_number_cross_validation_splits: when recovering child runs from AML hyperdrive
|
||||
sometimes the get_children function fails to retrieve all children. If the number of child runs
|
||||
sometimes the get_children function fails to retrieve all children. If the number of child runs
|
||||
retrieved by AML is lower than the expected number of splits, we try to retrieve them manually.
|
||||
"""
|
||||
if is_ensemble_run(run):
|
||||
|
@ -159,6 +163,7 @@ def to_azure_friendly_string(x: Optional[str]) -> Optional[str]:
|
|||
def to_azure_friendly_container_path(path: Path) -> str:
|
||||
"""
|
||||
Converts a path an Azure friendly container path by replacing "\\", "//" with "/" so it can be in the form: a/b/c.
|
||||
|
||||
:param path: Original path
|
||||
:return: Converted path
|
||||
"""
|
||||
|
@ -168,6 +173,7 @@ def to_azure_friendly_container_path(path: Path) -> str:
|
|||
def is_offline_run_context(run_context: Run) -> bool:
|
||||
"""
|
||||
Tells if a run_context is offline by checking if it has an experiment associated with it.
|
||||
|
||||
:param run_context: Context of the run to check
|
||||
:return:
|
||||
"""
|
||||
|
@ -177,6 +183,7 @@ def is_offline_run_context(run_context: Run) -> bool:
|
|||
def get_run_context_or_default(run: Optional[Run] = None) -> Run:
|
||||
"""
|
||||
Returns the context of the run, if run is not None. If run is None, returns the context of the current run.
|
||||
|
||||
:param run: Run to retrieve context for. If None, retrieve ocntext of current run.
|
||||
:return: Run context
|
||||
"""
|
||||
|
@ -186,6 +193,7 @@ def get_run_context_or_default(run: Optional[Run] = None) -> Run:
|
|||
def get_cross_validation_split_index(run: Run) -> int:
|
||||
"""
|
||||
Gets the cross validation index from the run's tags or returns the default
|
||||
|
||||
:param run: Run context from which to get index
|
||||
:return: The cross validation split index
|
||||
"""
|
||||
|
@ -204,6 +212,7 @@ def is_cross_validation_child_run(run: Run) -> bool:
|
|||
"""
|
||||
Checks the provided run's tags to determine if it is a cross validation child run
|
||||
(which is the case if the split index >=0)
|
||||
|
||||
:param run: Run to check.
|
||||
:return: True if cross validation run. False otherwise.
|
||||
"""
|
||||
|
@ -213,6 +222,7 @@ def is_cross_validation_child_run(run: Run) -> bool:
|
|||
def strip_prefix(string: str, prefix: str) -> str:
|
||||
"""
|
||||
Returns the string without the prefix if it has the prefix, otherwise the string unchanged.
|
||||
|
||||
:param string: Input string.
|
||||
:param prefix: Prefix to remove from input string.
|
||||
:return: Input string with prefix removed.
|
||||
|
@ -226,6 +236,7 @@ def get_all_environment_files(project_root: Path) -> List[Path]:
|
|||
"""
|
||||
Returns a list of all Conda environment files that should be used. This is firstly the InnerEye conda file,
|
||||
and possibly a second environment.yml file that lives at the project root folder.
|
||||
|
||||
:param project_root: The root folder of the code that starts the present training run.
|
||||
:return: A list with 1 or 2 entries that are conda environment files.
|
||||
"""
|
||||
|
@ -260,6 +271,7 @@ def download_run_output_file(blob_path: Path, destination: Path, run: Run) -> Pa
|
|||
Downloads a single file from the run's default output directory: DEFAULT_AML_UPLOAD_DIR ("outputs").
|
||||
For example, if blobs_path = "foo/bar.csv", then the run result file "outputs/foo/bar.csv" will be downloaded
|
||||
to <destination>/bar.csv (the directory will be stripped off).
|
||||
|
||||
:param blob_path: The name of the file to download.
|
||||
:param run: The AzureML run to download the files from
|
||||
:param destination: Local path to save the downloaded blob to.
|
||||
|
@ -287,6 +299,7 @@ def download_run_outputs_by_prefix(
|
|||
have a given prefix (folder structure). When saving, the prefix string will be stripped off. For example,
|
||||
if blobs_prefix = "foo", and the run has a file "outputs/foo/bar.csv", it will be downloaded to destination/bar.csv.
|
||||
If there is in addition a file "foo.txt", that file will be skipped.
|
||||
|
||||
:param blobs_prefix: The prefix for all files in "outputs" that should be downloaded.
|
||||
:param run: The AzureML run to download the files from.
|
||||
:param destination: Local path to save the downloaded blobs to.
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any
|
|||
def _is_empty(item: Any) -> bool:
|
||||
"""
|
||||
Returns True if the argument has length 0.
|
||||
|
||||
:param item: Object to check.
|
||||
:return: True if the argument has length 0. False otherwise.
|
||||
"""
|
||||
|
@ -18,6 +19,7 @@ def _is_empty(item: Any) -> bool:
|
|||
def _is_empty_or_empty_string_list(item: Any) -> bool:
|
||||
"""
|
||||
Returns True if the argument has length 0, or a list with a single element that has length 0.
|
||||
|
||||
:param item: Object to check.
|
||||
:return: True if argument has length 0, or a list with a single element that has length 0. False otherwise.
|
||||
"""
|
||||
|
@ -32,6 +34,7 @@ def value_to_string(x: object) -> str:
|
|||
"""
|
||||
Returns a string representation of x, with special treatment of Enums (return their value)
|
||||
and lists (return comma-separated list).
|
||||
|
||||
:param x: Object to convert to string
|
||||
:return: The string representation of the object.
|
||||
Special cases: For Enums, returns their value, for lists, returns a comma-separated list.
|
||||
|
|
|
@ -19,10 +19,11 @@ PYTEST_RESULTS_FILE = Path("test-results-on-azure-ml.xml")
|
|||
def run_pytest(pytest_mark: str, outputs_folder: Path) -> Tuple[bool, Path]:
|
||||
"""
|
||||
Runs pytest on the whole test suite, restricting to the tests that have the given PyTest mark.
|
||||
|
||||
:param pytest_mark: The PyTest mark to use for filtering out the tests to run.
|
||||
:param outputs_folder: The folder into which the test result XML file should be written.
|
||||
:return: True if PyTest found tests to execute and completed successfully, False otherwise.
|
||||
Also returns the path to the generated PyTest results file.
|
||||
Also returns the path to the generated PyTest results file.
|
||||
"""
|
||||
from _pytest.main import ExitCode
|
||||
_outputs_file = outputs_folder / PYTEST_RESULTS_FILE
|
||||
|
@ -43,6 +44,7 @@ def download_pytest_result(run: Run, destination_folder: Path = Path.cwd()) -> P
|
|||
"""
|
||||
Downloads the pytest result file that is stored in the output folder of the given AzureML run.
|
||||
If there is no pytest result file, throw an Exception.
|
||||
|
||||
:param run: The run from which the files should be read.
|
||||
:param destination_folder: The folder into which the PyTest result file is downloaded.
|
||||
:return: The path (folder and filename) of the downloaded file.
|
||||
|
|
|
@ -28,6 +28,7 @@ class SecretsHandling:
|
|||
def __init__(self, project_root: Path) -> None:
|
||||
"""
|
||||
Creates a new instance of the class.
|
||||
|
||||
:param project_root: The root folder of the project that starts the InnerEye run.
|
||||
"""
|
||||
self.project_root = project_root
|
||||
|
@ -36,8 +37,9 @@ class SecretsHandling:
|
|||
"""
|
||||
Reads the secrets from file in YAML format, and returns the contents as a dictionary. The YAML file is expected
|
||||
in the project root directory.
|
||||
|
||||
:param secrets_to_read: The list of secret names to read from the YAML file. These will be converted to
|
||||
uppercase.
|
||||
uppercase.
|
||||
:return: A dictionary with secrets, or None if the file does not exist.
|
||||
"""
|
||||
secrets_file = self.project_root / fixed_paths.PROJECT_SECRETS_FILE
|
||||
|
@ -57,8 +59,9 @@ class SecretsHandling:
|
|||
Attempts to read secrets from the project secret file. If there is no secrets file, it returns all secrets
|
||||
in secrets_to_read read from environment variables. When reading from environment, if an expected
|
||||
secret is not found, its value will be None.
|
||||
|
||||
:param secrets_to_read: The list of secret names to read from the YAML file. These will be converted to
|
||||
uppercase.
|
||||
uppercase.
|
||||
"""
|
||||
# Read all secrets from a local file if present, and sets the matching environment variables.
|
||||
# If no secrets file is present, no environment variable is modified or created.
|
||||
|
@ -69,9 +72,10 @@ class SecretsHandling:
|
|||
def get_secret_from_environment(self, name: str, allow_missing: bool = False) -> Optional[str]:
|
||||
"""
|
||||
Gets a password or key from the secrets file or environment variables.
|
||||
|
||||
:param name: The name of the environment variable to read. It will be converted to uppercase.
|
||||
:param allow_missing: If true, the function returns None if there is no entry of the given name in
|
||||
any of the places searched. If false, missing entries will raise a ValueError.
|
||||
any of the places searched. If false, missing entries will raise a ValueError.
|
||||
:return: Value of the secret. None, if there is no value and allow_missing is True.
|
||||
"""
|
||||
|
||||
|
@ -99,10 +103,11 @@ def read_all_settings(project_settings_file: Optional[Path] = None,
|
|||
the `project_root` folder. Settings in the private settings file
|
||||
override those in the project settings. Both settings files are expected in YAML format, with an entry called
|
||||
'variables'.
|
||||
|
||||
:param project_settings_file: The first YAML settings file to read.
|
||||
:param project_root: The folder that can contain a 'InnerEyePrivateSettings.yml' file.
|
||||
:return: A dictionary mapping from string to variable value. The dictionary key is the union of variable names
|
||||
found in the two settings files.
|
||||
found in the two settings files.
|
||||
"""
|
||||
private_settings_file = None
|
||||
if project_root and project_root.is_dir():
|
||||
|
@ -117,10 +122,11 @@ def read_settings_and_merge(project_settings_file: Optional[Path] = None,
|
|||
file is read into a dictionary, then the private settings file is read. Settings in the private settings file
|
||||
override those in the project settings. Both settings files are expected in YAML format, with an entry called
|
||||
'variables'.
|
||||
|
||||
:param project_settings_file: The first YAML settings file to read.
|
||||
:param private_settings_file: The second YAML settings file to read. Settings in this file has higher priority.
|
||||
:return: A dictionary mapping from string to variable value. The dictionary key is the union of variable names
|
||||
found in the two settings files.
|
||||
found in the two settings files.
|
||||
"""
|
||||
result = dict()
|
||||
if project_settings_file:
|
||||
|
@ -138,6 +144,7 @@ def read_settings_yaml_file(yaml_file: Path) -> Dict[str, Any]:
|
|||
"""
|
||||
Reads a YAML file, that is expected to contain an entry 'variables'. Returns the dictionary for the 'variables'
|
||||
section of the file.
|
||||
|
||||
:param yaml_file: The yaml file to read.
|
||||
:return: A dictionary with the variables from the yaml file.
|
||||
"""
|
||||
|
|
|
@ -45,6 +45,7 @@ class AMLTensorBoardMonitorConfig(GenericConfig):
|
|||
def monitor(monitor_config: AMLTensorBoardMonitorConfig, azure_config: AzureConfig) -> None:
|
||||
"""
|
||||
Starts TensorBoard monitoring as per the provided arguments.
|
||||
|
||||
:param monitor_config: The config containing information on which runs that need be monitored.
|
||||
:param azure_config: An AzureConfig object with secrets/keys to access the workspace.
|
||||
"""
|
||||
|
@ -93,9 +94,10 @@ def main(settings_yaml_file: Optional[Path] = None,
|
|||
"""
|
||||
Parses the commandline arguments, and based on those, starts the Tensorboard monitoring for the AzureML runs
|
||||
supplied on the commandline.
|
||||
|
||||
:param settings_yaml_file: The YAML file that contains all information for accessing Azure.
|
||||
:param project_root: The root folder that contains all code for the present run. This is only used to locate
|
||||
a private settings file InnerEyePrivateSettings.yml.
|
||||
a private settings file InnerEyePrivateSettings.yml.
|
||||
"""
|
||||
monitor_config = AMLTensorBoardMonitorConfig.parse_args()
|
||||
settings_yaml_file = settings_yaml_file or monitor_config.settings
|
||||
|
|
|
@ -88,9 +88,10 @@ MINIMUM_INSTANCE_COUNT = 10
|
|||
def compose_distribution_comparisons(file_contents: List[List[List[str]]]) -> List[str]:
|
||||
"""
|
||||
Composes comparisons as detailed above.
|
||||
|
||||
:param file_contents: two or more lists of rows, where each "rows" is returned by read_csv_file on
|
||||
(typically) a statistics.csv file
|
||||
:return a list of lines to print
|
||||
(typically) a statistics.csv file
|
||||
:return: a list of lines to print
|
||||
"""
|
||||
value_lists: List[Dict[str, List[float]]] = [parse_values(rows) for rows in file_contents]
|
||||
return compose_distribution_comparisons_on_lists(value_lists)
|
||||
|
@ -121,6 +122,7 @@ def mann_whitney_on_key(key: str, lists: List[List[float]]) -> List[Tuple[Tuple[
|
|||
Applies Mann-Whitney test to all sets of values (in lists) for the given key,
|
||||
and return a line of results, paired with some values for ordering purposes.
|
||||
Member lists with fewer than MINIMUM_INSTANCE_COUNT items are discarded.
|
||||
|
||||
:param key: statistic name; "Vol" statistics have mm^3 replaced by cm^3 for convenience.
|
||||
:param lists: list of lists of values
|
||||
"""
|
||||
|
@ -185,7 +187,7 @@ def roc_value(lst1: List[float], lst2: List[float]) -> float:
|
|||
:param lst1: a list of numbers
|
||||
:param lst2: another list of numbers
|
||||
:return: the proportion of pairs (x, y), where x is from lst1 and y is from lst2, for which
|
||||
x < y, with x == y counting as half an instance.
|
||||
x < y, with x == y counting as half an instance.
|
||||
"""
|
||||
if len(lst1) == 0 or len(lst2) == 0:
|
||||
return 0.5
|
||||
|
@ -256,6 +258,7 @@ def read_csv_file(input_file: str) -> List[List[str]]:
|
|||
"""
|
||||
Reads and returns the contents of a csv file. Empty rows (which can
|
||||
result from end-of-line mismatches) are dropped.
|
||||
|
||||
:param input_file: path to a file in csv format
|
||||
:return: list of rows from the file
|
||||
"""
|
||||
|
@ -270,7 +273,7 @@ def compare_scores_across_institutions(metrics_file: str, splits_to_use: str = "
|
|||
:param splits_to_use: a comma-separated list of split names
|
||||
:param mode_to_use: test, validation etc
|
||||
:return: a list of comparison lines between pairs of splits. If splits_to_use is non empty,
|
||||
only pairs involving at least one split from that set are compared.
|
||||
only pairs involving at least one split from that set are compared.
|
||||
"""
|
||||
valid_splits = set(splits_to_use.split(",")) if splits_to_use else None
|
||||
metrics = pd.read_csv(metrics_file)
|
||||
|
@ -327,7 +330,7 @@ def get_arguments(arglist: List[str] = None) -> Tuple[Optional[argparse.Namespac
|
|||
The value of the "-a" switch is one or more split names; pairs of splits not including these
|
||||
will not be compared.
|
||||
:return: parsed arguments and identifier for pattern (1, 2, 3 as above), or None, None if none of the
|
||||
patterns are followed
|
||||
patterns are followed
|
||||
"""
|
||||
# Use argparse because we want to have mandatory non-switch arguments, which GenericConfig doesn't support.
|
||||
parser = argparse.ArgumentParser("Run Mann-Whitney tests")
|
||||
|
|
|
@ -72,6 +72,7 @@ def report_structure_extremes(dataset_dir: str, azure_config: AzureConfig) -> No
|
|||
Writes structure-extreme lines for the subjects in a directory.
|
||||
If there are any structures with missing slices, a ValueError is raised after writing all the lines.
|
||||
This allows a build failure to be triggered when such structures exist.
|
||||
|
||||
:param azure_config: An object with all necessary information for accessing Azure.
|
||||
:param dataset_dir: directory containing subject subdirectories with integer names.
|
||||
"""
|
||||
|
@ -129,7 +130,7 @@ def report_structure_extremes_for_subject(subj_dir: str, series_id: str) -> Iter
|
|||
"""
|
||||
:param subj_dir: subject directory, containing <structure>.nii.gz files
|
||||
:param series_id: series identifier for the subject
|
||||
Yields a line for every <structure>.nii.gz file in the directory.
|
||||
Yields a line for every <structure>.nii.gz file in the directory.
|
||||
"""
|
||||
subject = os.path.basename(subj_dir)
|
||||
series_prefix = "" if series_id is None else series_id[:8]
|
||||
|
@ -174,7 +175,7 @@ def extent_list(presence: np.array, max_value: int) -> Tuple[List[int], List[str
|
|||
:param presence: a 1-D array of distinct integers in increasing order.
|
||||
:param max_value: any integer, not necessarily related to presence
|
||||
:return: two tuples: (1) a list of the minimum and maximum values of presence, and max_value;
|
||||
(2) a list of strings, each denoting a missing range of values within "presence".
|
||||
(2) a list of strings, each denoting a missing range of values within "presence".
|
||||
"""
|
||||
if len(presence) == 0:
|
||||
return [-1, -1, max_value], []
|
||||
|
|
|
@ -97,6 +97,7 @@ class WilcoxonTestConfig(GenericConfig):
|
|||
def calculate_statistics(dist1: Dict[str, float], dist2: Dict[str, float], factor: float) -> Dict[str, float]:
|
||||
"""
|
||||
Select common pairs and run the hypothesis test.
|
||||
|
||||
:param dist1: mapping from keys to scores
|
||||
:param dist2: mapping from keys (some or all in common with dist1) to scores
|
||||
:param factor: factor to divide the Wilcoxon "z" value by to determine p-value
|
||||
|
@ -131,10 +132,11 @@ def difference_counts(values1: List[float], values2: List[float]) -> Tuple[int,
|
|||
"""
|
||||
Returns the number of corresponding pairs from vals1 and vals2
|
||||
in which val1 > val2 and val2 > val1 respectively.
|
||||
|
||||
:param values1: list of values
|
||||
:param values2: list of values, same length as values1
|
||||
:return: number of pairs in which first value is greater than second, and number of pairs
|
||||
in which second is greater than first
|
||||
in which second is greater than first
|
||||
"""
|
||||
n1 = 0
|
||||
n2 = 0
|
||||
|
@ -159,6 +161,7 @@ def evaluate_data_pair(data1: Dict[str, Dict[str, float]], data2: Dict[str, Dict
|
|||
-> Dict[str, Dict[str, float]]:
|
||||
"""
|
||||
Find and compare dice scores for each structure
|
||||
|
||||
:param data1: dictionary from structure names, to dictionary from subjects to scores
|
||||
:param data2: another such dictionary, sharing some structure names
|
||||
:param is_raw_p_value: whether to use "raw" Wilcoxon z values when calculating p values (rather than reduce)
|
||||
|
@ -240,6 +243,7 @@ def wilcoxon_signed_rank_test(args: WilcoxonTestConfig,
|
|||
"""
|
||||
Reads data from a csv file, and performs all pairwise comparisons, except if --against was specified,
|
||||
compare every other run against the "--against" run.
|
||||
|
||||
:param args: parsed command line parameters
|
||||
:param name_shortener: optional function to shorten names to make graphs and tables more legible
|
||||
"""
|
||||
|
@ -262,6 +266,7 @@ def run_wilcoxon_test_on_data(data: Dict[str, Dict[str, Dict[str, float]]],
|
|||
"""
|
||||
Performs all pairwise comparisons on the provided data, except if "against" was specified,
|
||||
compare every other run against the "against" run.
|
||||
|
||||
:param data: scores such that data[run][structure][subject] = dice score
|
||||
:param against: runs to compare against; or None to compare all against all
|
||||
:param raw: whether to interpret Wilcoxon Z values "raw" or apply a correction
|
||||
|
|
|
@ -74,9 +74,10 @@ def get_best_epoch_results_path(mode: ModelExecutionMode,
|
|||
"""
|
||||
For a given model execution mode, creates the relative results path
|
||||
in the form BEST_EPOCH_FOLDER_NAME/(Train, Test or Val)
|
||||
|
||||
:param mode: model execution mode
|
||||
:param model_proc: whether this is for an ensemble or single model. If ensemble, we return a different path
|
||||
to avoid colliding with the results from the single model that may have been created earlier in the same run.
|
||||
to avoid colliding with the results from the single model that may have been created earlier in the same run.
|
||||
"""
|
||||
subpath = Path(BEST_EPOCH_FOLDER_NAME) / mode.value
|
||||
if model_proc == ModelProcessing.ENSEMBLE_CREATION:
|
||||
|
@ -108,6 +109,7 @@ def any_pairwise_larger(items1: Any, items2: Any) -> bool:
|
|||
def check_is_any_of(message: str, actual: Optional[str], valid: Iterable[Optional[str]]) -> None:
|
||||
"""
|
||||
Raises an exception if 'actual' is not any of the given valid values.
|
||||
|
||||
:param message: The prefix for the error message.
|
||||
:param actual: The actual value.
|
||||
:param valid: The set of valid strings that 'actual' is allowed to take on.
|
||||
|
@ -136,8 +138,9 @@ def logging_to_stdout(log_level: Union[int, str] = logging.INFO) -> None:
|
|||
"""
|
||||
Instructs the Python logging libraries to start writing logs to stdout up to the given logging level.
|
||||
Logging will use a timestamp as the prefix, using UTC.
|
||||
|
||||
:param log_level: The logging level. All logging message with a level at or above this level will be written to
|
||||
stdout. log_level can be numeric, or one of the pre-defined logging strings (INFO, DEBUG, ...).
|
||||
stdout. log_level can be numeric, or one of the pre-defined logging strings (INFO, DEBUG, ...).
|
||||
"""
|
||||
log_level = standardize_log_level(log_level)
|
||||
logger = logging.getLogger()
|
||||
|
@ -186,6 +189,7 @@ def logging_to_file(file_path: Path) -> None:
|
|||
Instructs the Python logging libraries to start writing logs to the given file.
|
||||
Logging will use a timestamp as the prefix, using UTC. The logging level will be the same as defined for
|
||||
logging to stdout.
|
||||
|
||||
:param file_path: The path and name of the file to write to.
|
||||
"""
|
||||
# This function can be called multiple times, and should only add a handler during the first call.
|
||||
|
@ -219,6 +223,7 @@ def logging_only_to_file(file_path: Path, stdout_log_level: Union[int, str] = lo
|
|||
Redirects logging to the specified file, undoing that on exit. If logging is currently going
|
||||
to stdout, messages at level stdout_log_level or higher (typically ERROR) are also sent to stdout.
|
||||
Usage: with logging_only_to_file(my_log_path): do_stuff()
|
||||
|
||||
:param file_path: file to log to
|
||||
:param stdout_log_level: mininum level for messages to also go to stdout
|
||||
"""
|
||||
|
@ -253,6 +258,7 @@ def logging_section(gerund: str) -> Generator:
|
|||
to help people locate particular sections. Usage:
|
||||
with logging_section("doing this and that"):
|
||||
do_this_and_that()
|
||||
|
||||
:param gerund: string expressing what happens in this section of the log.
|
||||
"""
|
||||
from time import time
|
||||
|
@ -301,16 +307,19 @@ def check_properties_are_not_none(obj: Any, ignore: Optional[List[str]] = None)
|
|||
|
||||
def initialize_instance_variables(func: Callable) -> Callable:
|
||||
"""
|
||||
Automatically assigns the input parameters.
|
||||
Automatically assigns the input parameters. Example usage::
|
||||
|
||||
class process:
|
||||
@initialize_instance_variables
|
||||
def __init__(self, cmd, reachable=False, user='root'):
|
||||
pass
|
||||
p = process('halt', True)
|
||||
print(p.cmd, p.reachable, p.user)
|
||||
|
||||
Outputs::
|
||||
|
||||
('halt', True, 'root')
|
||||
|
||||
>>> class process:
|
||||
... @initialize_instance_variables
|
||||
... def __init__(self, cmd, reachable=False, user='root'):
|
||||
... pass
|
||||
>>> p = process('halt', True)
|
||||
>>> # noinspection PyUnresolvedReferences
|
||||
>>> p.cmd, p.reachable, p.user
|
||||
('halt', True, 'root')
|
||||
"""
|
||||
names, varargs, keywords, defaults, _, _, _ = inspect.getfullargspec(func)
|
||||
|
||||
|
@ -354,6 +363,7 @@ def is_gpu_tensor(data: Any) -> bool:
|
|||
def print_exception(ex: Exception, message: str, logger_fn: Callable = logging.error) -> None:
|
||||
"""
|
||||
Prints information about an exception, and the full traceback info.
|
||||
|
||||
:param ex: The exception that was caught.
|
||||
:param message: An additional prefix that is printed before the exception itself.
|
||||
:param logger_fn: The logging function to use for logging this exception
|
||||
|
@ -366,6 +376,7 @@ def print_exception(ex: Exception, message: str, logger_fn: Callable = logging.e
|
|||
def namespace_to_path(namespace: str, root: PathOrString = repository_root_directory()) -> Path:
|
||||
"""
|
||||
Given a namespace (in form A.B.C) and an optional root directory R, create a path R/A/B/C
|
||||
|
||||
:param namespace: Namespace to convert to path
|
||||
:param root: Path to prefix (default is project root)
|
||||
:return:
|
||||
|
@ -377,6 +388,7 @@ def path_to_namespace(path: Path, root: PathOrString = repository_root_directory
|
|||
"""
|
||||
Given a path (in form R/A/B/C) and an optional root directory R, create a namespace A.B.C.
|
||||
If root is provided, then path must be a relative child to it.
|
||||
|
||||
:param path: Path to convert to namespace
|
||||
:param root: Path prefix to remove from namespace (default is project root)
|
||||
:return:
|
||||
|
|
|
@ -14,6 +14,7 @@ from InnerEye.Common.type_annotations import PathOrString
|
|||
def repository_root_directory(path: Optional[PathOrString] = None) -> Path:
|
||||
"""
|
||||
Gets the full path to the root directory that holds the present repository.
|
||||
|
||||
:param path: if provided, a relative path to append to the absolute path to the repository root.
|
||||
:return: The full path to the repository's root directory, with symlinks resolved if any.
|
||||
"""
|
||||
|
@ -28,6 +29,7 @@ def repository_root_directory(path: Optional[PathOrString] = None) -> Path:
|
|||
def repository_parent_directory(path: Optional[PathOrString] = None) -> Path:
|
||||
"""
|
||||
Gets the full path to the parent directory that holds the present repository.
|
||||
|
||||
:param path: if provided, a relative path to append to the absolute path to the repository root.
|
||||
:return: The full path to the repository's root directory, with symlinks resolved if any.
|
||||
"""
|
||||
|
|
|
@ -99,9 +99,11 @@ class GenericConfig(param.Parameterized):
|
|||
def __init__(self, should_validate: bool = True, throw_if_unknown_param: bool = False, **params: Any):
|
||||
"""
|
||||
Instantiates the config class, ignoring parameters that are not overridable.
|
||||
|
||||
:param should_validate: If True, the validate() method is called directly after init.
|
||||
:param throw_if_unknown_param: If True, raise an error if the provided "params" contains any key that does not
|
||||
correspond to an attribute of the class.
|
||||
|
||||
:param params: Parameters to set.
|
||||
"""
|
||||
# check if illegal arguments are passed in
|
||||
|
@ -157,6 +159,7 @@ class GenericConfig(param.Parameterized):
|
|||
"""
|
||||
Adds all overridable fields of the current class to the given argparser.
|
||||
Fields that are marked as readonly, constant or private are ignored.
|
||||
|
||||
:param parser: Parser to add properties to.
|
||||
"""
|
||||
|
||||
|
@ -165,6 +168,7 @@ class GenericConfig(param.Parameterized):
|
|||
Parse a string as a bool. Supported values are case insensitive and one of:
|
||||
'on', 't', 'true', 'y', 'yes', '1' for True
|
||||
'off', 'f', 'false', 'n', 'no', '0' for False.
|
||||
|
||||
:param x: string to test.
|
||||
:return: Bool value if string valid, otherwise a ValueError is raised.
|
||||
"""
|
||||
|
@ -179,6 +183,7 @@ class GenericConfig(param.Parameterized):
|
|||
"""
|
||||
Given a parameter, get its basic Python type, e.g.: param.Boolean -> bool.
|
||||
Throw exception if it is not supported.
|
||||
|
||||
:param _p: parameter to get type and nargs for.
|
||||
:return: Type
|
||||
"""
|
||||
|
@ -222,6 +227,7 @@ class GenericConfig(param.Parameterized):
|
|||
Add a boolean argument.
|
||||
If the parameter default is False then allow --flag (to set it True) and --flag=Bool as usual.
|
||||
If the parameter default is True then allow --no-flag (to set it to False) and --flag=Bool as usual.
|
||||
|
||||
:param parser: parser to add a boolean argument to.
|
||||
:param k: argument name.
|
||||
:param p: boolean parameter.
|
||||
|
@ -318,6 +324,7 @@ class GenericConfig(param.Parameterized):
|
|||
"""
|
||||
Logs a warning for every parameter whose value is not as given in "values", other than those
|
||||
in keys_to_ignore.
|
||||
|
||||
:param values: override dictionary, parameter names to values
|
||||
:param keys_to_ignore: set of dictionary keys not to report on
|
||||
:return: None
|
||||
|
@ -347,6 +354,7 @@ def create_from_matching_params(from_object: param.Parameterized, cls_: Type[T])
|
|||
Creates an object of the given target class, and then copies all attributes from the `from_object` to
|
||||
the newly created object, if there is a matching attribute. The target class must be a subclass of
|
||||
param.Parameterized.
|
||||
|
||||
:param from_object: The object to read attributes from.
|
||||
:param cls_: The name of the class for the newly created object.
|
||||
:return: An instance of cls_
|
||||
|
|
|
@ -33,6 +33,7 @@ class OutputFolderForTests:
|
|||
"""
|
||||
Creates a full path for the given file or folder name relative to the root directory stored in the present
|
||||
object.
|
||||
|
||||
:param file_or_folder_name: Name of file or folder to be created under root_dir
|
||||
"""
|
||||
return self.root_dir / file_or_folder_name
|
||||
|
@ -40,6 +41,7 @@ class OutputFolderForTests:
|
|||
def make_sub_dir(self, dir_name: str) -> Path:
|
||||
"""
|
||||
Makes a sub directory under root_dir
|
||||
|
||||
:param dir_name: Name of subdirectory to be created.
|
||||
"""
|
||||
sub_dir_path = self.create_file_or_folder_path(dir_name)
|
||||
|
|
|
@ -28,6 +28,7 @@ COL_VALUE = "value"
|
|||
def memory_in_gb(bytes: int) -> float:
|
||||
"""
|
||||
Converts a memory amount in bytes to gigabytes.
|
||||
|
||||
:param bytes:
|
||||
:return:
|
||||
"""
|
||||
|
@ -63,6 +64,7 @@ class GpuUtilization:
|
|||
def max(self, other: GpuUtilization) -> GpuUtilization:
|
||||
"""
|
||||
Computes the metric-wise maximum of the two GpuUtilization objects.
|
||||
|
||||
:param other:
|
||||
:return:
|
||||
"""
|
||||
|
@ -103,8 +105,9 @@ class GpuUtilization:
|
|||
"""
|
||||
Lists all metrics stored in the present object, as (metric_name, value) pairs suitable for logging in
|
||||
Tensorboard.
|
||||
|
||||
:param prefix: If provided, this string as used as an additional prefix for the metric name itself. If prefix
|
||||
is "max", the metric would look like "maxLoad_Percent"
|
||||
is "max", the metric would look like "maxLoad_Percent"
|
||||
:return: A list of (name, value) tuples.
|
||||
"""
|
||||
return [
|
||||
|
@ -118,6 +121,7 @@ class GpuUtilization:
|
|||
def from_gpu(gpu: GPU) -> GpuUtilization:
|
||||
"""
|
||||
Creates a GpuUtilization object from data coming from the gputil library.
|
||||
|
||||
:param gpu: GPU diagnostic data from gputil.
|
||||
:return:
|
||||
"""
|
||||
|
@ -145,10 +149,11 @@ class ResourceMonitor(Process):
|
|||
csv_results_folder: Path):
|
||||
"""
|
||||
Creates a process that will monitor CPU and GPU utilization.
|
||||
|
||||
:param interval_seconds: The interval in seconds at which usage statistics should be written.
|
||||
:param tensorboard_folder: The path in which to create a tensorboard logfile.
|
||||
:param csv_results_folder: The path in which the CSV file with aggregate metrics will be created.
|
||||
When running in AzureML, this should NOT reside inside the /logs folder.
|
||||
When running in AzureML, this should NOT reside inside the /logs folder.
|
||||
"""
|
||||
super().__init__(name="Resource Monitor", daemon=True)
|
||||
self._interval_seconds = interval_seconds
|
||||
|
@ -163,6 +168,7 @@ class ResourceMonitor(Process):
|
|||
def log_to_tensorboard(self, label: str, value: float) -> None:
|
||||
"""
|
||||
Write a scalar metric value to Tensorboard, marked with the present step.
|
||||
|
||||
:param label: The name of the metric.
|
||||
:param value: The value.
|
||||
"""
|
||||
|
@ -172,6 +178,7 @@ class ResourceMonitor(Process):
|
|||
"""
|
||||
Updates the stored GPU utilization metrics with the current status coming from gputil, and logs
|
||||
them to Tensorboard.
|
||||
|
||||
:param gpus: The current utilization information, read from gputil, for all available GPUs.
|
||||
"""
|
||||
for gpu in gpus:
|
||||
|
|
|
@ -14,12 +14,13 @@ def spawn_and_monitor_subprocess(process: str,
|
|||
"""
|
||||
Helper function to start a subprocess, passing in a given set of arguments, and monitor it.
|
||||
Returns the subprocess exit code and the list of lines written to stdout.
|
||||
|
||||
:param process: The name and path of the executable to spawn.
|
||||
:param args: The args to the process.
|
||||
:param env: The environment variables that the new process will run with. If not provided, copy the
|
||||
environment from the current process.
|
||||
environment from the current process.
|
||||
:return: Return code after the process has finished, and the list of lines that were written to stdout by the
|
||||
subprocess.
|
||||
subprocess.
|
||||
"""
|
||||
if env is None:
|
||||
env = dict(os.environ.items())
|
||||
|
|
|
@ -202,6 +202,7 @@ class CovidDataset(InnerEyeCXRDatasetWithReturnIndex):
|
|||
def _split_dataset(self, val_split: float, seed: int) -> Tuple[Subset, Subset]:
|
||||
"""
|
||||
Implements val - train split.
|
||||
|
||||
:param val_split: proportion to use for validation
|
||||
:param seed: random seed for splitting
|
||||
:return: dataset_train, dataset_val
|
||||
|
|
|
@ -37,8 +37,10 @@ class InnerEyeVisionDataModule(VisionDataModule):
|
|||
:param dataset_cls: class to load the dataset. Expected to inherit from InnerEyeDataClassBaseWithReturnIndex and
|
||||
VisionDataset. See InnerEyeCIFAR10 for an example.
|
||||
BEWARE VisionDataModule expects the first positional argument of your class to be the data directory.
|
||||
|
||||
:param return_index: whether the return the index in __get_item__, the dataset_cls is expected to implement
|
||||
this logic.
|
||||
this logic.
|
||||
|
||||
:param train_transforms: transforms to use at training time
|
||||
:param val_transforms: transforms to use at validation time
|
||||
:param data_dir: data directory where to find the data
|
||||
|
@ -113,7 +115,7 @@ class CombinedDataModule(LightningDataModule):
|
|||
|
||||
:param encoder_module: datamodule to use for training of SSL.
|
||||
:param linear_head_module: datamodule to use for training of linear head on top of frozen encoder. Can use a
|
||||
different batch size than the encoder module. CombinedDataModule logic will take care of aggregation.
|
||||
different batch size than the encoder module. CombinedDataModule logic will take care of aggregation.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.encoder_module = encoder_module
|
||||
|
@ -151,6 +153,7 @@ class CombinedDataModule(LightningDataModule):
|
|||
"""
|
||||
Creates a CombinedLoader from the data loaders for the encoder and the linear head.
|
||||
The cycle mode is chosen such that in all cases the encoder dataset is only cycled through once.
|
||||
|
||||
:param encoder_loader: The dataloader to use for the SSL encoder.
|
||||
:param linear_head_loader: The dataloader to use for the linear head.
|
||||
"""
|
||||
|
|
|
@ -30,12 +30,14 @@ def get_ssl_transforms_from_config(config: CfgNode,
|
|||
|
||||
:param config: configuration defining which augmentations to apply as well as their intensities.
|
||||
:param return_two_views_per_sample: if True the resulting transforms will return two versions of each sample they
|
||||
are called on. If False, simply return one transformed version of the sample centered and cropped.
|
||||
are called on. If False, simply return one transformed version of the sample centered and cropped.
|
||||
|
||||
:param use_training_augmentations_for_validation: If True, use augmentation at validation time too.
|
||||
This is required for SSL validation loss to be meaningful. If False, only apply basic processing step
|
||||
(no augmentations)
|
||||
This is required for SSL validation loss to be meaningful. If False, only apply basic processing step
|
||||
(no augmentations)
|
||||
|
||||
:param expand_channels: if True the expand channel transformation from InnerEye.ML.augmentations.image_transforms
|
||||
will be added to the transformation passed through the config. This is needed for single channel images as CXR.
|
||||
will be added to the transformation passed through the config. This is needed for single channel images as CXR.
|
||||
"""
|
||||
train_transforms = create_transforms_from_config(config, apply_augmentations=True,
|
||||
expand_channels=expand_channels)
|
||||
|
|
|
@ -63,6 +63,7 @@ def get_encoder_output_dim(
|
|||
) -> int:
|
||||
"""
|
||||
Calculates the output dimension of ssl encoder by making a single forward pass.
|
||||
|
||||
:param pl_module: pl encoder module
|
||||
:param dm: pl datamodule
|
||||
"""
|
||||
|
|
|
@ -203,10 +203,10 @@ class SSLContainer(LightningContainer):
|
|||
Returns torch lightning data module for encoder or linear head
|
||||
|
||||
:param is_ssl_encoder_module: whether to return the data module for SSL training or for linear head. If true,
|
||||
:return transforms with two views per sample (batch like (img_v1, img_v2, label)). If False, return only one
|
||||
view per sample but also return the index of the sample in the dataset (to make sure we don't use twice the same
|
||||
batch in one training epoch (batch like (index, img_v1, label), as classifier dataloader expected to be shorter
|
||||
than SSL training, hence CombinedDataloader might loop over data several times per epoch).
|
||||
:return: transforms with two views per sample (batch like (img_v1, img_v2, label)). If False, return only one
|
||||
view per sample but also return the index of the sample in the dataset (to make sure we don't use twice the same
|
||||
batch in one training epoch (batch like (index, img_v1, label), as classifier dataloader expected to be shorter
|
||||
than SSL training, hence CombinedDataloader might loop over data several times per epoch).
|
||||
"""
|
||||
datamodule_args = self.datamodule_args[SSLDataModuleType.ENCODER] if is_ssl_encoder_module else \
|
||||
self.datamodule_args[SSLDataModuleType.LINEAR_HEAD]
|
||||
|
@ -237,12 +237,15 @@ class SSLContainer(LightningContainer):
|
|||
is_ssl_encoder_module: bool) -> Tuple[Any, Any]:
|
||||
"""
|
||||
Returns the transformation pipeline for training and validation.
|
||||
|
||||
:param augmentation_config: optional yaml config defining strength of augmenentations. Ignored for CIFAR
|
||||
examples.
|
||||
examples.
|
||||
|
||||
:param dataset_name: name of the dataset, value has to be in SSLDatasetName, determines which transformation
|
||||
pipeline to return.
|
||||
pipeline to return.
|
||||
|
||||
:param is_ssl_encoder_module: if True the transformation pipeline will yield two versions of the image it is
|
||||
applied on and it applies the training transformations also at validation time. Note that if your
|
||||
applied on and it applies the training transformations also at validation time. Note that if your
|
||||
transformation does not contain any randomness, the pipeline will return two identical copies. If False, it
|
||||
will return only one transformation.
|
||||
:return: training transformation pipeline and validation transformation pipeline.
|
||||
|
|
|
@ -19,11 +19,11 @@ class SSLClassifierContainer(SSLContainer):
|
|||
"""
|
||||
This module is used to train a linear classifier on top of a frozen (or not) encoder.
|
||||
|
||||
If you are running on AML, you can specify the SSL training run id via the --pretraining_run_recovery_id flag. This
|
||||
will automatically download the checkpoints for you and take the latest one as the starting weights of your
|
||||
If you are running on AML, you can specify the SSL training run id via the ``--pretraining_run_recovery_id`` flag.
|
||||
This will automatically download the checkpoints for you and take the latest one as the starting weights of your
|
||||
classifier.
|
||||
|
||||
If you are running locally, you can specify the path to your SSL weights via the --local_ssl_weights_path parameter.
|
||||
If you are running locally, you can specify the path to your SSL weights via the ``--local_ssl_weights_path`` flag.
|
||||
|
||||
See docs/self_supervised_models.md for more details.
|
||||
"""
|
||||
|
|
|
@ -40,15 +40,18 @@ class BYOLInnerEye(pl.LightningModule):
|
|||
**kwargs: Any) -> None:
|
||||
"""
|
||||
Args:
|
||||
|
||||
:param num_samples: Number of samples present in training dataset / dataloader.
|
||||
:param learning_rate: Optimizer learning rate.
|
||||
:param batch_size: Sample batch size used in gradient updates.
|
||||
:param encoder_name: Type of CNN encoder used to extract image embeddings. The options are:
|
||||
{'resnet18', 'resnet50', 'resnet101', 'densenet121'}.
|
||||
|
||||
:param warmup_epochs: Number of epochs for scheduler warm up (linear increase from 0 to base_lr).
|
||||
:param use_7x7_first_conv_in_resnet: If True, use a 7x7 kernel (default) in the first layer of resnet.
|
||||
If False, replace first layer by a 3x3 kernel. This is required for small CIFAR 32x32 images to not
|
||||
If False, replace first layer by a 3x3 kernel. This is required for small CIFAR 32x32 images to not
|
||||
shrink them.
|
||||
|
||||
:param weight_decay: L2-norm weight decay.
|
||||
"""
|
||||
super().__init__()
|
||||
|
@ -78,10 +81,12 @@ class BYOLInnerEye(pl.LightningModule):
|
|||
"""
|
||||
Returns the BYOL loss for a given batch of images, used in validation
|
||||
and training step.
|
||||
|
||||
:param batch: assumed to be a batch a Tuple(List[tensor, tensor, tensor], tensor) to match lightning-bolts
|
||||
SimCLRTrainDataTransform API; the first tuple element contains a list of three tensor where the two first
|
||||
SimCLRTrainDataTransform API; the first tuple element contains a list of three tensor where the two first
|
||||
elements contain two are two strong augmented versions of the original images in the batch and the last
|
||||
is a milder augmentation (ignored here).
|
||||
|
||||
:param batch_idx: index of the batch
|
||||
:return: BYOL loss
|
||||
"""
|
||||
|
@ -115,7 +120,10 @@ class BYOLInnerEye(pl.LightningModule):
|
|||
self.train_iters_per_epoch = self.hparams.num_samples // global_batch_size # type: ignore
|
||||
|
||||
def configure_optimizers(self) -> Any:
|
||||
# exclude certain parameters
|
||||
"""Testing this out
|
||||
|
||||
:return: _description_
|
||||
"""
|
||||
parameters = self.exclude_from_wt_decay(self.online_network.named_parameters(),
|
||||
weight_decay=self.hparams.weight_decay) # type: ignore
|
||||
optimizer = Adam(parameters,
|
||||
|
|
|
@ -41,6 +41,7 @@ class SimCLRInnerEye(SimCLR):
|
|||
**kwargs: Any) -> None:
|
||||
"""
|
||||
Returns SimCLR pytorch-lightning module, based on lightning-bolts implementation.
|
||||
|
||||
:param encoder_name: Image encoder name (predefined models)
|
||||
:param dataset_name: Dataset name (e.g. cifar10, kaggle, etc.)
|
||||
:param use_7x7_first_conv_in_resnet: If True, use a 7x7 kernel (default) in the first layer of resnet.
|
||||
|
|
|
@ -38,6 +38,7 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
|
|||
|
||||
:param class_weights: The class weights to use when computing the cross entropy loss. If set to None,
|
||||
no weighting will be done.
|
||||
|
||||
:param length_linear_head_loader: The maximum number of batches in the dataloader for the linear head.
|
||||
"""
|
||||
|
||||
|
@ -123,6 +124,7 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
|
|||
def to_device(batch: Any, device: Union[str, torch.device]) -> Tuple[T, T]:
|
||||
"""
|
||||
Moves batch to device.
|
||||
|
||||
:param device: device to move the batch to.
|
||||
"""
|
||||
_, x, y = batch
|
||||
|
|
|
@ -38,9 +38,10 @@ def load_yaml_augmentation_config(config_path: Path) -> CfgNode:
|
|||
def create_ssl_encoder(encoder_name: str, use_7x7_first_conv_in_resnet: bool = True) -> torch.nn.Module:
|
||||
"""
|
||||
Creates SSL encoder.
|
||||
|
||||
:param encoder_name: available choices: resnet18, resnet50, resnet101 and densenet121.
|
||||
:param use_7x7_first_conv_in_resnet: If True, use a 7x7 kernel (default) in the first layer of resnet.
|
||||
If False, replace first layer by a 3x3 kernel. This is required for small CIFAR 32x32 images to not shrink them.
|
||||
If False, replace first layer by a 3x3 kernel. This is required for small CIFAR 32x32 images to not shrink them.
|
||||
"""
|
||||
from pl_bolts.models.self_supervised.resnets import resnet18, resnet50, resnet101
|
||||
from InnerEye.ML.SSL.encoders import DenseNet121Encoder
|
||||
|
|
|
@ -20,9 +20,9 @@ def random_select_patch_center(sample: Sample, class_weights: List[float] = None
|
|||
class.
|
||||
|
||||
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
|
||||
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
|
||||
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
|
||||
:return numpy int array (3x1) containing patch center spatial coordinates
|
||||
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop voxel belongs
|
||||
to (must sum to 1), uniform distribution assumed if none provided.
|
||||
:return: numpy int array (3x1) containing patch center spatial coordinates
|
||||
"""
|
||||
num_classes = sample.labels.shape[0]
|
||||
|
||||
|
@ -70,9 +70,9 @@ def slicers_for_random_crop(sample: Sample,
|
|||
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
|
||||
:param crop_size: The size of the crop expressed as a list of 3 ints, one per spatial dimension.
|
||||
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
|
||||
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
|
||||
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
|
||||
:return: Tuple element 1: The slicers that convert the input image to the chosen crop. Tuple element 2: The
|
||||
indices of the center point of the crop.
|
||||
indices of the center point of the crop.
|
||||
:raises ValueError: If there are shape mismatches among the arguments or if the crop size is larger than the image.
|
||||
"""
|
||||
shape = sample.image.shape[1:]
|
||||
|
|
|
@ -23,7 +23,7 @@ class RandomGamma:
|
|||
def __init__(self, scale: Tuple[float, float]) -> None:
|
||||
"""
|
||||
:param scale: a tuple (min_gamma, max_gamma) that specifies the range of possible values to sample the gamma
|
||||
value from when the transformation is called.
|
||||
value from when the transformation is called.
|
||||
"""
|
||||
self.scale = scale
|
||||
|
||||
|
@ -59,6 +59,7 @@ class AddGaussianNoise:
|
|||
def __init__(self, p_apply: float, std: float) -> None:
|
||||
"""
|
||||
Transformation to add Gaussian noise N(0, std) to an image.
|
||||
|
||||
:param: p_apply: probability of applying the transformation.
|
||||
:param: std: standard deviation of the gaussian noise to add to the image.
|
||||
"""
|
||||
|
|
|
@ -34,10 +34,11 @@ class ImageTransformationPipeline:
|
|||
use_different_transformation_per_channel: bool = False):
|
||||
"""
|
||||
:param transforms: List of transformations to apply to images. Supports out of the boxes torchvision transforms
|
||||
as they accept data of arbitrary dimension. You can also define your own transform class but be aware that you
|
||||
as they accept data of arbitrary dimension. You can also define your own transform class but be aware that you
|
||||
function should expect input of shape [C, Z, H, W] and apply the same transformation to each Z slice.
|
||||
|
||||
:param use_different_transformation_per_channel: if True, apply a different version of the augmentation pipeline
|
||||
for each channel. If False, applies the same transformation to each channel, separately.
|
||||
for each channel. If False, applies the same transformation to each channel, separately.
|
||||
"""
|
||||
self.use_different_transformation_per_channel = use_different_transformation_per_channel
|
||||
self.pipeline = Compose(transforms) if isinstance(transforms, List) else transforms
|
||||
|
@ -93,11 +94,13 @@ def create_transforms_from_config(config: CfgNode,
|
|||
Defines the image transformations pipeline from a config file. It has been designed for Chest X-Ray
|
||||
images but it can be used for other types of images data, type of augmentations to use and strength are
|
||||
expected to be defined in the config. The channel expansion is needed for gray images.
|
||||
|
||||
:param config: config yaml file fixing strength and type of augmentation to apply
|
||||
:param apply_augmentations: if True return transformation pipeline with augmentations. Else,
|
||||
disable augmentations i.e. only resize and center crop the image.
|
||||
disable augmentations i.e. only resize and center crop the image.
|
||||
|
||||
:param expand_channels: if True the expand channel transformation from InnerEye.ML.augmentations.image_transforms
|
||||
will be added to the transformation passed through the config. This is needed for single channel images as CXR.
|
||||
will be added to the transformation passed through the config. This is needed for single channel images as CXR.
|
||||
"""
|
||||
transforms: List[Any] = []
|
||||
if expand_channels:
|
||||
|
|
|
@ -111,12 +111,12 @@ def download_and_compare_scores(outputs_folder: Path, azure_config: AzureConfig,
|
|||
"""
|
||||
:param azure_config: Azure configuration to use for downloading data
|
||||
:param comparison_blob_storage_paths: list of paths to directories containing metrics.csv and dataset.csv files,
|
||||
each of the form run_recovery_id/rest_of_path
|
||||
each of the form run_recovery_id/rest_of_path
|
||||
:param model_dataset_df: dataframe containing contents of dataset.csv for the current model
|
||||
:param model_metrics_df: dataframe containing contents of metrics.csv for the current model
|
||||
:return: a dataframe for all the data (current model and all baselines); whether any comparisons were
|
||||
done, i.e. whether a valid baseline was found; and the text lines to be written to the Wilcoxon results
|
||||
file.
|
||||
done, i.e. whether a valid baseline was found; and the text lines to be written to the Wilcoxon results
|
||||
file.
|
||||
"""
|
||||
comparison_baselines = get_comparison_baselines(outputs_folder, azure_config, comparison_blob_storage_paths)
|
||||
result = perform_score_comparisons(model_dataset_df, model_metrics_df, comparison_baselines)
|
||||
|
@ -195,10 +195,11 @@ def compare_files(expected: Path, actual: Path, csv_relative_tolerance: float =
|
|||
basis.
|
||||
|
||||
:param expected: A file that contains the expected contents. The type of comparison (text or binary) is chosen
|
||||
based on the extension of this file.
|
||||
based on the extension of this file.
|
||||
|
||||
:param actual: A file that contains the actual contents.
|
||||
:param csv_relative_tolerance: When comparing CSV files, use this as the maximum allowed relative discrepancy.
|
||||
If 0.0, do not allow any discrepancy.
|
||||
If 0.0, do not allow any discrepancy.
|
||||
:return: An empty string if the files appear identical, or otherwise an error message with details.
|
||||
"""
|
||||
|
||||
|
@ -257,9 +258,9 @@ def compare_folder_contents(expected_folder: Path,
|
|||
:param actual_folder: The output folder with the actually produced files.
|
||||
:param run: An AzureML run
|
||||
:param csv_relative_tolerance: When comparing CSV files, use this as the maximum allowed relative discrepancy.
|
||||
If 0.0, do not allow any discrepancy.
|
||||
If 0.0, do not allow any discrepancy.
|
||||
:return: A list of human readable error messages, with message and file path. If no errors are found, the list is
|
||||
empty.
|
||||
empty.
|
||||
"""
|
||||
messages = []
|
||||
if run and is_offline_run_context(run):
|
||||
|
@ -302,7 +303,7 @@ def compare_folders_and_run_outputs(expected: Path, actual: Path, csv_relative_t
|
|||
:param expected: A folder with files that are expected to be present.
|
||||
:param actual: The output folder with the actually produced files.
|
||||
:param csv_relative_tolerance: When comparing CSV files, use this as the maximum allowed relative discrepancy.
|
||||
If 0.0, do not allow any discrepancy.
|
||||
If 0.0, do not allow any discrepancy.
|
||||
"""
|
||||
if not expected.is_dir():
|
||||
raise ValueError(f"Folder with expected files does not exist: {expected}")
|
||||
|
|
|
@ -86,6 +86,7 @@ def create_unique_timestamp_id() -> str:
|
|||
def get_best_checkpoint_path(path: Path) -> Path:
|
||||
"""
|
||||
Given a path and checkpoint, formats a path based on the checkpoint file name format.
|
||||
|
||||
:param path to checkpoint folder
|
||||
"""
|
||||
return path / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
|
||||
|
|
|
@ -47,12 +47,12 @@ class CovidModel(ScalarModelBase):
|
|||
"""
|
||||
Model to train a CovidDataset model from scratch or finetune from SSL-pretrained model.
|
||||
|
||||
For AML you need to provide the run_id of your SSL training job as a command line argument
|
||||
--pretraining_run_recovery_id=id_of_your_ssl_model, this will download the checkpoints of the run to your
|
||||
For AML you need to provide the run_id of your SSL training job as a command line argument:
|
||||
``--pretraining_run_recovery_id=<id_of_your_ssl_model>``. This will download the checkpoints of the run to your
|
||||
machine and load the corresponding pretrained model.
|
||||
|
||||
To recover from a particular checkpoint from your SSL run e.g. "recovery_epoch=499.ckpt" please use the
|
||||
--name_of_checkpoint argument.
|
||||
To recover from a particular checkpoint from your SSL run e.g. ``"recovery_epoch=499.ckpt"`` please use the
|
||||
``--name_of_checkpoint`` argument.
|
||||
"""
|
||||
use_pretrained_model = param.Boolean(default=False, doc="If True, start training from a model pretrained with SSL."
|
||||
"If False, start training a DenseNet model from scratch"
|
||||
|
@ -242,9 +242,11 @@ class CovidModel(ScalarModelBase):
|
|||
Generate a custom report for the Covid model. This report will read the file model_output.csv generated for
|
||||
the training, validation or test sets and compute both the multiclass accuracy and the accuracy for each of the
|
||||
hierarchical tasks.
|
||||
|
||||
:param report_dir: Directory report is to be written to
|
||||
:param model_proc: Whether this is a single or ensemble model (model_output.csv will be located in different
|
||||
paths for single vs ensemble runs.)
|
||||
paths for single vs ensemble runs.)
|
||||
|
||||
"""
|
||||
|
||||
label_prefix = LoggingColumns.Label.value
|
||||
|
@ -364,6 +366,11 @@ class CovidModel(ScalarModelBase):
|
|||
|
||||
class DicomPreparation:
|
||||
def __call__(self, item: torch.Tensor) -> PIL.Image:
|
||||
"""Call class as a function. This will act as a transformation function for the dataset.
|
||||
|
||||
:param item: tensor to transform.
|
||||
:return: transformed data.
|
||||
"""
|
||||
# Item will be of dimension [C, Z, X, Y]
|
||||
images = item.numpy()
|
||||
assert images.shape[0] == 1 and images.shape[1] == 1
|
||||
|
|
|
@ -35,8 +35,9 @@ class HelloDataset(Dataset):
|
|||
def __init__(self, raw_data: List[List[float]]) -> None:
|
||||
"""
|
||||
Creates the 1-dim regression dataset.
|
||||
|
||||
:param raw_data: The raw data, e.g. from a cross validation split or loaded from file. This
|
||||
must be numeric data which can be converted into a tensor. See the static method
|
||||
must be numeric data which can be converted into a tensor. See the static method
|
||||
from_path_and_indexes for an example call.
|
||||
"""
|
||||
super().__init__()
|
||||
|
@ -55,6 +56,7 @@ class HelloDataset(Dataset):
|
|||
end_index: int) -> 'HelloDataset':
|
||||
'''
|
||||
Static method to instantiate a HelloDataset from the root folder with the start and end indexes.
|
||||
|
||||
:param root_folder: The folder in which the data file lives ("hellocontainer.csv")
|
||||
:param start_index: The first row to read.
|
||||
:param end_index: The last row to read (exclusive)
|
||||
|
@ -137,6 +139,7 @@ class HelloRegression(LightningModule):
|
|||
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
|
||||
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
|
||||
It runs a forward pass of a tensor through the model.
|
||||
|
||||
:param x: The input tensor(s)
|
||||
:return: The model output.
|
||||
"""
|
||||
|
@ -148,6 +151,7 @@ class HelloRegression(LightningModule):
|
|||
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
|
||||
It consumes a minibatch of training data (coming out of the data loader), does forward propagation, and
|
||||
computes the loss.
|
||||
|
||||
:param batch: The batch of training data
|
||||
:return: The loss value with a computation graph attached.
|
||||
"""
|
||||
|
@ -162,6 +166,7 @@ class HelloRegression(LightningModule):
|
|||
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
|
||||
It consumes a minibatch of validation data (coming out of the data loader), does forward propagation, and
|
||||
computes the loss.
|
||||
|
||||
:param batch: The batch of validation data
|
||||
:return: The loss value on the validation data.
|
||||
"""
|
||||
|
@ -173,6 +178,7 @@ class HelloRegression(LightningModule):
|
|||
"""
|
||||
This is a convenience method to reduce code duplication, because training, validation, and test step share
|
||||
large amounts of code.
|
||||
|
||||
:param batch: The batch of data to process, with input data and targets.
|
||||
:return: The MSE loss that the model achieved on this batch.
|
||||
"""
|
||||
|
@ -207,6 +213,7 @@ class HelloRegression(LightningModule):
|
|||
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
|
||||
It evaluates the model in "inference mode" on data coming from the test set. It could, for example,
|
||||
also write each model prediction to disk.
|
||||
|
||||
:param batch: The batch of test data.
|
||||
:param batch_idx: The index (0, 1, ...) of the batch when the data loader is enumerated.
|
||||
:return: The loss on the test data.
|
||||
|
|
|
@ -47,11 +47,12 @@ def get_fastmri_data_module(azure_dataset_id: str,
|
|||
Creates a LightningDataModule that consumes data from the FastMRI challenge. The type of challenge
|
||||
(single/multicoil) is determined from the name of the dataset in Azure blob storage. The mask type is set to
|
||||
equispaced, with 4x acceleration.
|
||||
|
||||
:param azure_dataset_id: The name of the dataset (folder name in blob storage).
|
||||
:param local_dataset: The local folder at which the dataset has been mounted or downloaded.
|
||||
:param sample_rate: Fraction of slices of the training data split to use. Set to a value <1.0 for rapid prototyping.
|
||||
:param test_path: The name of the folder inside the dataset that contains the test data.
|
||||
:return: A LightningDataModule object.
|
||||
:return: The FastMRI LightningDataModule object.
|
||||
"""
|
||||
if not azure_dataset_id:
|
||||
raise ValueError("The azure_dataset_id argument must be provided.")
|
||||
|
|
|
@ -47,17 +47,23 @@ class HeadAndNeckBase(SegmentationModelBase):
|
|||
**kwargs: Any) -> None:
|
||||
"""
|
||||
Creates a new instance of the class.
|
||||
|
||||
:param ground_truth_ids: List of ground truth ids.
|
||||
:param ground_truth_ids_display_names: Optional list of ground truth id display names. If
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
|
||||
:param colours: Optional list of colours. If
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
|
||||
:param fill_holes: Optional list of fill hole flags. If
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
|
||||
:param roi_interpreted_types: Optional list of roi_interpreted_types. If
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
|
||||
:param class_weights: Optional list of class weights. If
|
||||
present then must be of the same length as ground_truth_ids + 1.
|
||||
present then must be of the same length as ground_truth_ids + 1.
|
||||
|
||||
:param slice_exclusion_rules: Optional list of SliceExclusionRules.
|
||||
:param summed_probability_rules: Optional list of SummedProbabilityRule.
|
||||
:param num_feature_channels: Optional number of feature channels.
|
||||
|
|
|
@ -29,6 +29,7 @@ class HeadAndNeckPaper(HeadAndNeckBase):
|
|||
def __init__(self, num_structures: Optional[int] = None, **kwargs: Any) -> None:
|
||||
"""
|
||||
Creates a new instance of the class.
|
||||
|
||||
:param num_structures: number of structures from STRUCTURE_LIST to predict (default: all structures)
|
||||
:param kwargs: Additional arguments that will be passed through to the SegmentationModelBase constructor.
|
||||
"""
|
||||
|
|
|
@ -94,6 +94,7 @@ class HelloWorld(SegmentationModelBase):
|
|||
https://docs.microsoft.com/en-us/azure/machine-learning/service/how-to-tune-hyperparameters
|
||||
A reference is provided at https://docs.microsoft.com/en-us/python/api/azureml-train-core/azureml.train
|
||||
.hyperdrive?view=azure-ml-py
|
||||
|
||||
:param run_config: The configuration for running an individual experiment.
|
||||
:return: An Azure HyperDrive run configuration (configured PyTorch environment).
|
||||
"""
|
||||
|
|
|
@ -28,17 +28,23 @@ class ProstateBase(SegmentationModelBase):
|
|||
**kwargs: Any) -> None:
|
||||
"""
|
||||
Creates a new instance of the class.
|
||||
|
||||
:param ground_truth_ids: List of ground truth ids.
|
||||
:param ground_truth_ids_display_names: Optional list of ground truth id display names. If
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
|
||||
:param colours: Optional list of colours. If
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
|
||||
:param fill_holes: Optional list of fill hole flags. If
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
|
||||
:param interpreted_types: Optional list of interpreted_types. If
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
present then must be of the same length as ground_truth_ids.
|
||||
|
||||
:param class_weights: Optional list of class weights. If
|
||||
present then must be of the same length as ground_truth_ids + 1.
|
||||
present then must be of the same length as ground_truth_ids + 1.
|
||||
|
||||
:param kwargs: Additional arguments that will be passed through to the SegmentationModelBase constructor.
|
||||
"""
|
||||
ground_truth_ids_display_names = ground_truth_ids_display_names or [f"zz_{name}" for name in ground_truth_ids]
|
||||
|
|
|
@ -20,6 +20,7 @@ class ProstatePaper(ProstateBase):
|
|||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""
|
||||
Creates a new instance of the class.
|
||||
|
||||
:param kwargs: Additional arguments that will be passed through to the SegmentationModelBase constructor.
|
||||
"""
|
||||
ground_truth_ids = fg_classes
|
||||
|
|
|
@ -91,7 +91,7 @@ class PyTorchPassthroughModel(BaseSegmentationModel):
|
|||
Ignore the actual patches and return a fixed segmentation, explained in make_nesting_rectangles.
|
||||
|
||||
:param patches: Set of patches, of shape (#patches, #image_channels, Z, Y, X). Only the shape
|
||||
is used.
|
||||
is used.
|
||||
:return: Fixed tensor of shape (#patches, number_of_classes, Z, Y, Z).
|
||||
"""
|
||||
output_size: TupleInt3 = (patches.shape[2], patches.shape[3], patches.shape[4])
|
||||
|
|
|
@ -57,6 +57,7 @@ class CroppingDataset(FullImageDataset):
|
|||
"""
|
||||
Pad the original sample such the the provided images has the same
|
||||
(or slightly larger in case of uneven difference) shape to the output_size, using the provided padding mode.
|
||||
|
||||
:param sample: Sample to pad.
|
||||
:param crop_size: Crop size to match.
|
||||
:param padding_mode: The padding scheme to apply.
|
||||
|
@ -89,10 +90,12 @@ class CroppingDataset(FullImageDataset):
|
|||
class_weights: Optional[List[float]] = None) -> CroppedSample:
|
||||
"""
|
||||
Creates an instance of a cropped sample extracted from full 3D images.
|
||||
|
||||
:param sample: the full size 3D sample to use for extracting a cropped sample.
|
||||
:param crop_size: the size of the crop to extract.
|
||||
:param center_size: the size of the center of the crop (this should be the same as the spatial dimensions
|
||||
of the posteriors that the model produces)
|
||||
|
||||
:param class_weights: the distribution to use for the crop center class.
|
||||
:return: CroppedSample
|
||||
"""
|
||||
|
|
|
@ -31,6 +31,7 @@ def collate_with_metadata(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
|
|||
The collate function that the dataloader workers should use. It does the same thing for all "normal" fields
|
||||
(all fields are put into tensors with outer dimension batch_size), except for the special "metadata" field.
|
||||
Those metadata objects are collated into a simple list.
|
||||
|
||||
:param batch: A list of samples that should be collated.
|
||||
:return: collated result
|
||||
"""
|
||||
|
@ -84,7 +85,7 @@ class ImbalancedSampler(Sampler):
|
|||
"""
|
||||
|
||||
:param dataset: a dataset
|
||||
:num_samples: number of samples to draw. If None the number of samples
|
||||
:num_samples: number of samples to draw. If None the number of samples
|
||||
corresponds to the length of the dataset.
|
||||
"""
|
||||
self.dataset = dataset
|
||||
|
@ -123,6 +124,7 @@ class RepeatDataLoader(DataLoader):
|
|||
**kwargs: Any):
|
||||
"""
|
||||
Creates a new data loader.
|
||||
|
||||
:param dataset: The dataset that should be loaded.
|
||||
:param batch_size: The number of samples per minibatch.
|
||||
:param shuffle: If true, the dataset will be shuffled randomly.
|
||||
|
@ -204,11 +206,12 @@ class FullImageDataset(GeneralDataset):
|
|||
"""
|
||||
Dataset class that loads and creates samples with full 3D images from a given pd.Dataframe. The following
|
||||
are the operations performed to generate a sample from this dataset:
|
||||
-------------------------------------------------------------------------------------------------
|
||||
|
||||
1) On initialization parses the provided pd.Dataframe with dataset information, to cache the set of file paths
|
||||
and patient mappings to load as PatientDatasetSource. The sources are then saved in a list: dataset_sources.
|
||||
2) dataset_sources is iterated in a batched fashion, where for each batch it loads the full 3D images, and applies
|
||||
pre-processing functions (e.g. normalization), returning a sample that can be used for full image operations.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, args: SegmentationModelBase, data_frame: pd.DataFrame,
|
||||
|
@ -313,6 +316,7 @@ def load_dataset_sources(dataframe: pd.DataFrame,
|
|||
The dataframe contains per-patient per-channel image information, relative to a root directory.
|
||||
This method converts that into a per-patient dictionary, that contains absolute file paths
|
||||
separated for for image channels, ground truth channels, and mask channels.
|
||||
|
||||
:param dataframe: A dataframe read directly from a dataset CSV file.
|
||||
:param local_dataset_root_folder: The root folder that contains all images.
|
||||
:param image_channels: The names of the image channels that should be used in the result.
|
||||
|
|
|
@ -40,6 +40,7 @@ class PatientMetadata:
|
|||
For each of the columns "seriesId", "instituionId" and "tags", the distinct values for the given patient are
|
||||
computed. If there is exactly 1 distinct value, that is returned as the respective patient metadata. If there is
|
||||
more than 1 distinct value, the metadata column is set to None.
|
||||
|
||||
:param dataframe: The dataset to read from.
|
||||
:param patient_id: The ID of the patient for which the metadata should be extracted.
|
||||
:return: An instance of PatientMetadata for the given patient_id
|
||||
|
@ -101,8 +102,9 @@ class SampleBase:
|
|||
def from_dict(cls: Type[T], sample: Dict[str, Any]) -> T:
|
||||
"""
|
||||
Create an instance of the sample class, based on the provided sample dictionary
|
||||
|
||||
:param sample: dictionary of arguments
|
||||
:return:
|
||||
:return: an instance of the SampleBase class
|
||||
"""
|
||||
return cls(**sample) # type: ignore
|
||||
|
||||
|
@ -110,6 +112,7 @@ class SampleBase:
|
|||
"""
|
||||
Create a clone of the current sample, with the provided overrides to replace the
|
||||
existing properties if they exist.
|
||||
|
||||
:param overrides:
|
||||
:return:
|
||||
"""
|
||||
|
@ -118,7 +121,6 @@ class SampleBase:
|
|||
def get_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the current sample as a dictionary of property names and their values.
|
||||
:return:
|
||||
"""
|
||||
return vars(self)
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ def extract_label_classification(label_string: str, sample_id: str, num_classes:
|
|||
Converts a string from a dataset.csv file that contains a model's label to a scalar.
|
||||
|
||||
For classification datasets:
|
||||
If num_classes is 1 (binary classification tasks):
|
||||
If num_classes is 1 (binary classification tasks)
|
||||
The function maps ["1", "true", "yes"] to [1], ["0", "false", "no"] to [0].
|
||||
If the entry in the CSV file was missing (no string given at all) or an empty string, it returns math.nan.
|
||||
If num_classes is greater than 1 (multilabel datasets):
|
||||
|
@ -42,17 +42,17 @@ def extract_label_classification(label_string: str, sample_id: str, num_classes:
|
|||
map "1|3|4" to [0, 1, 0, 1, 1, 0]).
|
||||
If the entry in the CSV file was missing (no string given at all) or an empty string,
|
||||
this function returns an all-zero tensor (none of the label classes were positive for this sample).
|
||||
|
||||
For regression datasets:
|
||||
The function casts a string label to float. Raises an exception if the conversion is
|
||||
not possible.
|
||||
If the entry in the CSV file was missing (no string given at all) or an empty string, it returns math.nan.
|
||||
The function casts a string label to float. Raises an exception if the conversion is
|
||||
not possible.
|
||||
If the entry in the CSV file was missing (no string given at all) or an empty string, it returns math.nan.
|
||||
|
||||
:param label_string: The value of the label as read from CSV via a DataFrame.
|
||||
:param sample_id: The sample ID where this label was read from. This is only used for creating error messages.
|
||||
:param num_classes: Number of classes. This should be equal the size of the model output.
|
||||
For binary classification tasks, num_classes should be one. For multilabel classification tasks, num_classes should
|
||||
correspond to the number of label classes in the problem.
|
||||
For binary classification tasks, num_classes should be one. For multilabel classification tasks, num_classes
|
||||
should correspond to the number of label classes in the problem.
|
||||
|
||||
:param is_classification_dataset: If the model is a classification model
|
||||
:return: A list of floats with the same size as num_classes
|
||||
"""
|
||||
|
@ -120,9 +120,11 @@ def _get_single_channel_row(subject_rows: pd.DataFrame,
|
|||
'channel' argument. Throws a ValueError if there is no or more than 1 such row.
|
||||
The result is returned as a dictionary, not a DataFrame!
|
||||
If the 'channel' argument is null, the input is expected to be already 1 row, which is returned as a dictionary.
|
||||
|
||||
:param subject_rows: A set of rows all belonging to the same subject.
|
||||
:param channel: The value to look for in the `channel_column` column. This can be null. If it is null,
|
||||
the input `subject_rows` is expected to have exactly 1 row.
|
||||
the input `subject_rows` is expected to have exactly 1 row.
|
||||
|
||||
:param subject_id: A string describing the presently processed subject. This is only used for error reporting.
|
||||
:return: A dictionary mapping from column names to values, created from the unique row that was found.
|
||||
"""
|
||||
|
@ -144,6 +146,7 @@ def _string_to_float(text: Union[str, float], error_message_prefix: str = None)
|
|||
"""
|
||||
Converts a string coming from a dataset.csv file to a floating point number, taking into account all the
|
||||
corner cases that can happen when the dataset file is malformed.
|
||||
|
||||
:param text: The element coming from the dataset.csv file.
|
||||
:param error_message_prefix: A prefix string that will go into the error message if the conversion fails.
|
||||
:return: A floating point number, possibly np.nan.
|
||||
|
@ -181,27 +184,32 @@ def load_single_data_source(subject_rows: pd.DataFrame,
|
|||
"""
|
||||
Converts a set of dataset rows for a single subject to a ScalarDataSource instance, which contains the
|
||||
labels, the non-image features, and the paths to the image files.
|
||||
|
||||
:param num_classes: Number of classes, this is equivalent to model output tensor size
|
||||
:param channel_column: The name of the column that contains the row identifier ("channels")
|
||||
:param metadata_columns: A list of columns that well be added to the item metadata as key/value pairs.
|
||||
:param subject_rows: All dataset rows that belong to the same subject.
|
||||
:param subject_id: The identifier of the subject that is being processed.
|
||||
:param image_channels: The names of all channels (stored in the CSV_CHANNEL_HEADER column of the dataframe)
|
||||
that are expected to be loaded from disk later because they are large images.
|
||||
that are expected to be loaded from disk later because they are large images.
|
||||
|
||||
:param image_file_column: The name of the column that contains the image file names.
|
||||
:param label_channels: The name of the channel where the label scalar or vector is read from.
|
||||
:param label_value_column: The column that contains the value for the label scalar or vector.
|
||||
:param non_image_feature_channels: non_image_feature_channels: A dictonary of the names of all channels where
|
||||
additional scalar values should be read from. THe keys should map each feature to its channels.
|
||||
additional scalar values should be read from. THe keys should map each feature to its channels.
|
||||
|
||||
:param numerical_columns: The names of all columns where additional scalar values should be read from.
|
||||
:param categorical_data_encoder: Encoding scheme for categorical data.
|
||||
:param is_classification_dataset: If True, the dataset will be used in a classification model. If False,
|
||||
assume that the dataset will be used in a regression model.
|
||||
assume that the dataset will be used in a regression model.
|
||||
|
||||
:param transform_labels: a label transformation or a list of label transformation to apply to the labels.
|
||||
If a list is provided, the transformations are applied in order from left to right.
|
||||
If a list is provided, the transformations are applied in order from left to right.
|
||||
|
||||
:param sequence_position_numeric: Numeric position of the data source in a data sequence. Assumed to be
|
||||
a non-sequential dataset item if None provided (default).
|
||||
:return:
|
||||
a non-sequential dataset item if None provided (default).
|
||||
:return: A ScalarDataSource containing the specified data.
|
||||
"""
|
||||
|
||||
def _get_row_for_channel(channel: Optional[str]) -> Dict[str, str]:
|
||||
|
@ -234,6 +242,7 @@ def load_single_data_source(subject_rows: pd.DataFrame,
|
|||
"""
|
||||
Return either the list of channels for a given column or if None was passed as
|
||||
numerical channels i.e. there are no channel to be specified return [None].
|
||||
|
||||
:param non_image_channels: Dict mapping features name to their channels
|
||||
:param feature: feature name for which to return the channels
|
||||
:return: List of channels for the given feature.
|
||||
|
@ -351,15 +360,19 @@ class DataSourceReader():
|
|||
:param image_channels: The names of all channels (stored in the CSV_CHANNEL_HEADER column of the dataframe)
|
||||
:param label_channels: The name of the channel where the label scalar or vector is read from.
|
||||
:param transform_labels: a label transformation or a list of label transformation to apply to the labels.
|
||||
If a list is provided, the transformations are applied in order from left to right.
|
||||
If a list is provided, the transformations are applied in order from left to right.
|
||||
|
||||
:param non_image_feature_channels: non_image_feature_channels: A dictionary of the names of all channels where
|
||||
additional scalar values should be read from. The keys should map each feature to its channels.
|
||||
additional scalar values should be read from. The keys should map each feature to its channels.
|
||||
|
||||
:param numerical_columns: The names of all columns where additional scalar values should be read from.
|
||||
:param sequence_column: The name of the column that contains the sequence index, that will be stored in
|
||||
metadata.sequence_position. If this column name is not provided, the sequence_position will be 0.
|
||||
metadata.sequence_position. If this column name is not provided, the sequence_position will be 0.
|
||||
|
||||
:param subject_column: The name of the column that contains the subject identifier
|
||||
:param channel_column: The name of the column that contains the row identifier ("channels")
|
||||
that are expected to be loaded from disk later because they are large images.
|
||||
that are expected to be loaded from disk later because they are large images.
|
||||
|
||||
:param is_classification_dataset: If the current dataset is classification or not.
|
||||
:param categorical_data_encoder: Encoding scheme for categorical data.
|
||||
"""
|
||||
|
@ -422,6 +435,7 @@ class DataSourceReader():
|
|||
"""
|
||||
Loads dataset items from the given dataframe, where all column and channel configurations are taken from their
|
||||
respective model config elements.
|
||||
|
||||
:param data_frame: The dataframe to read dataset items from.
|
||||
:param args: The model configuration object.
|
||||
:return: A list of all dataset items that could be read from the dataframe.
|
||||
|
@ -452,13 +466,13 @@ class DataSourceReader():
|
|||
def load_data_sources(self, num_dataset_reader_workers: int = 0) -> List[ScalarDataSource]:
|
||||
"""
|
||||
Extracts information from a dataframe to create a list of ClassificationItem. This will create one entry per
|
||||
unique
|
||||
value of subject_id in the dataframe. The file is structured around "channels", indicated by specific values in
|
||||
the CSV_CHANNEL_HEADER column. The result contains paths to image files, a label vector, and a matrix of
|
||||
additional values that are specified by rows and columns given in non_image_feature_channels and
|
||||
unique value of subject_id in the dataframe. The file is structured around "channels", indicated by specific
|
||||
values in the CSV_CHANNEL_HEADER column. The result contains paths to image files, a label vector, and a matrix
|
||||
of additional values that are specified by rows and columns given in non_image_feature_channels and
|
||||
numerical_columns.
|
||||
|
||||
:param num_dataset_reader_workers: Number of worker processes to use, if 0 then single threaded execution,
|
||||
otherwise if -1 then multiprocessing with all available cpus will be used.
|
||||
otherwise if -1 then multiprocessing with all available cpus will be used.
|
||||
:return: A list of ScalarDataSource or SequenceDataSource instances
|
||||
"""
|
||||
subject_ids = self.data_frame[self.subject_column].unique()
|
||||
|
@ -512,9 +526,9 @@ def files_by_stem(root_path: Path) -> Dict[str, Path]:
|
|||
"""
|
||||
Lists all files under the given root directory recursively, and returns a mapping from file name stem to full path.
|
||||
The file name stem is computed more restrictively than what Path.stem returns: file.nii.gz will use "file" as the
|
||||
stem, not "file.nii" as Path.stem would.
|
||||
Only actual files are returned in the mapping, no directories.
|
||||
If there are multiple files that map to the same stem, the function raises a ValueError.
|
||||
stem, not "file.nii" as Path.stem would. Only actual files are returned in the mapping, no directories. If there are
|
||||
multiple files that map to the same stem, the function raises a ValueError.
|
||||
|
||||
:param root_path: The root directory from which the file search should start.
|
||||
:return: A dictionary mapping from file name stem to the full path to where the file is found.
|
||||
"""
|
||||
|
@ -546,11 +560,13 @@ def is_valid_item_index(item: ScalarDataSource,
|
|||
min_sequence_position_value: int = 0) -> bool:
|
||||
"""
|
||||
Returns True if the item metadata in metadata.sequence_position is a valid sequence index.
|
||||
|
||||
:param item: The item to check.
|
||||
:param min_sequence_position_value: Check if the item has a metadata.sequence_position that is at least
|
||||
the value given here. Default is 0.
|
||||
the value given here. Default is 0.
|
||||
|
||||
:param max_sequence_position_value: If provided then this is the maximum sequence position the sequence can
|
||||
end with. Longer sequences will be truncated. None is default.
|
||||
end with. Longer sequences will be truncated. None is default.
|
||||
:return: True if the item has a valid index.
|
||||
"""
|
||||
# If no max_sequence_position_value is given, we don't care about
|
||||
|
@ -572,9 +588,11 @@ def filter_valid_classification_data_sources_items(items: Iterable[ScalarDataSou
|
|||
|
||||
:param items: The list of items to filter.
|
||||
:param min_sequence_position_value: Restrict the data to items with a metadata.sequence_position that is at least
|
||||
the value given here. Default is 0.
|
||||
the value given here. Default is 0.
|
||||
|
||||
:param max_sequence_position_value: If provided then this is the maximum sequence position the sequence can
|
||||
end with. Longer sequences will be truncated. None is default.
|
||||
end with. Longer sequences will be truncated. None is default.
|
||||
|
||||
:param file_to_path_mapping: A mapping from a file name stem (without extension) to its full path.
|
||||
:return: A list of items, all of which are valid now.
|
||||
"""
|
||||
|
@ -637,9 +655,10 @@ class ScalarItemAugmentation:
|
|||
segmentation_transform: Optional[Callable] = None) -> None:
|
||||
"""
|
||||
:param image_transform: transformation function to apply to images field. If None, images field is unchanged by
|
||||
call.
|
||||
call.
|
||||
|
||||
:param segmentation_transform: transformation function to apply to segmentations field. If None segmentations
|
||||
field is unchanged by call.
|
||||
field is unchanged by call.
|
||||
"""
|
||||
self.image_transform = image_transform
|
||||
self.segmentation_transform = segmentation_transform
|
||||
|
@ -671,7 +690,8 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], ScalarDataSource):
|
|||
name: Optional[str] = None,
|
||||
sample_transform: Callable[[ScalarItem], ScalarItem] = ScalarItemAugmentation()):
|
||||
"""
|
||||
High level class for the scalar dataset designed to be inherited for specific behaviour
|
||||
High level class for the scalar dataset designed to be inherited for specific behaviour.
|
||||
|
||||
:param args: The model configuration object.
|
||||
:param data_frame: The dataframe to read from.
|
||||
:param feature_statistics: If given, the normalization factor for the non-image features is taken
|
||||
|
@ -691,7 +711,8 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], ScalarDataSource):
|
|||
def load_all_data_sources(self) -> List[ScalarDataSource]:
|
||||
"""
|
||||
Uses the dataframe to create data sources to be used by the dataset.
|
||||
:return:
|
||||
|
||||
:return: List of data sources.
|
||||
"""
|
||||
all_data_sources = DataSourceReader.load_data_sources_as_per_config(self.data_frame, self.args) # type: ignore
|
||||
self.status += f"Loading: {self.create_status_string(all_data_sources)}"
|
||||
|
@ -722,9 +743,10 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], ScalarDataSource):
|
|||
"""
|
||||
Loads the images and/or segmentations as given in the ClassificationDataSource item and
|
||||
applying the optional transformation specified by the class.
|
||||
|
||||
:param item: The item to load.
|
||||
:return: A ClassificationItem instances with the loaded images, and the labels and non-image features copied
|
||||
from the argument.
|
||||
from the argument.
|
||||
"""
|
||||
sample = item.load_images(
|
||||
root_path=self.args.local_dataset,
|
||||
|
@ -738,6 +760,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], ScalarDataSource):
|
|||
def create_status_string(self, items: List[ScalarDataSource]) -> str:
|
||||
"""
|
||||
Creates a human readable string that contains the number of items, and the distinct number of subjects.
|
||||
|
||||
:param items: Use the items provided to create the string
|
||||
:return: A string like "12 items for 5 subjects"
|
||||
"""
|
||||
|
@ -757,10 +780,12 @@ class ScalarDataset(ScalarDatasetBase):
|
|||
sample_transform: Callable[[ScalarItem], ScalarItem] = ScalarItemAugmentation()):
|
||||
"""
|
||||
Creates a new scalar dataset from a dataframe.
|
||||
|
||||
:param args: The model configuration object.
|
||||
:param data_frame: The dataframe to read from.
|
||||
:param feature_statistics: If given, the normalization factor for the non-image features is taken
|
||||
from the values provided. If None, the normalization factor is computed from the data in the present dataset.
|
||||
from the values provided. If None, the normalization factor is computed from the data in the present dataset.
|
||||
|
||||
:param sample_transform: Sample transforms that should be applied.
|
||||
:param name: Name of the dataset, used for diagnostics logging
|
||||
"""
|
||||
|
@ -802,6 +827,7 @@ class ScalarDataset(ScalarDatasetBase):
|
|||
one class index. The value stored will be the number of samples that belong to the positive class.
|
||||
In the multilabel case, this returns a dictionary with class indices and samples per class as the key-value
|
||||
pairs.
|
||||
|
||||
:return: Dictionary of {class_index: count}
|
||||
"""
|
||||
all_labels = [torch.flatten(torch.nonzero(item.label).int()).tolist() for item in self.items] # [N, 1]
|
||||
|
|
|
@ -39,7 +39,6 @@ class ScalarItemBase(SampleBase):
|
|||
def id(self) -> str:
|
||||
"""
|
||||
Gets the identifier of the present object from metadata.
|
||||
:return:
|
||||
"""
|
||||
return self.metadata.id # type: ignore
|
||||
|
||||
|
@ -47,7 +46,6 @@ class ScalarItemBase(SampleBase):
|
|||
def props(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Gets the general metadata dictionary for the present object.
|
||||
:return:
|
||||
"""
|
||||
return self.metadata.props # type: ignore
|
||||
|
||||
|
@ -94,6 +92,7 @@ class ScalarItem(ScalarItemBase):
|
|||
"""
|
||||
Creates a copy of the present object where all tensors live on the given CUDA device.
|
||||
The metadata field is left unchanged.
|
||||
|
||||
:param device: The CUDA or GPU device to move to.
|
||||
:return: A new `ScalarItem` with all tensors on the chosen device.
|
||||
"""
|
||||
|
@ -124,16 +123,19 @@ class ScalarDataSource(ScalarItemBase):
|
|||
root_path argument, or it must contain a file name stem only (without extension). In this case, the actual
|
||||
mapping from file name stem to full path is expected in the file_mapping argument.
|
||||
Either of 'root_path' or 'file_mapping' must be provided.
|
||||
|
||||
:param root_path: The root path where all channel files for images are expected. This is ignored if
|
||||
file_mapping is given.
|
||||
file_mapping is given.
|
||||
|
||||
:param file_mapping: A mapping from a file name stem (without extension) to its full path.
|
||||
:param load_segmentation: If True it loads segmentation if present on the same file as the image.
|
||||
:param center_crop_size: If supplied, all loaded images will be cropped to the size given here. The crop will
|
||||
be taken from the center of the image.
|
||||
be taken from the center of the image.
|
||||
|
||||
:param image_size: If given, all loaded images will be reshaped to the size given here, prior to the
|
||||
center crop.
|
||||
center crop.
|
||||
:return: An instance of ClassificationItem, with the same label and numerical_non_image_features fields,
|
||||
and all images loaded.
|
||||
and all images loaded.
|
||||
"""
|
||||
full_channel_files = self.get_all_image_filepaths(root_path=root_path,
|
||||
file_mapping=file_mapping)
|
||||
|
@ -156,8 +158,9 @@ class ScalarDataSource(ScalarItemBase):
|
|||
"""
|
||||
Checks if all file paths and non-image features are present in the object. All image channel files must
|
||||
be not None, and none of the non imaging features may be NaN or infinity.
|
||||
|
||||
:return: True if channel files is a list with not-None entries, and all non imaging features are finite
|
||||
floating point numbers.
|
||||
floating point numbers.
|
||||
"""
|
||||
return self.files_valid() and super().is_valid()
|
||||
|
||||
|
@ -169,8 +172,10 @@ class ScalarDataSource(ScalarItemBase):
|
|||
file_mapping: Optional[Dict[str, Path]]) -> List[Path]:
|
||||
"""
|
||||
Get a list of image paths for the object. Either root_path or file_mapping must be specified.
|
||||
|
||||
:param root_path: The root path where all channel files for images are expected. This is ignored if
|
||||
file_mapping is given.
|
||||
file_mapping is given.
|
||||
|
||||
:param file_mapping: A mapping from a file name stem (without extension) to its full path.
|
||||
"""
|
||||
full_channel_files: List[Path] = []
|
||||
|
@ -188,9 +193,11 @@ class ScalarDataSource(ScalarItemBase):
|
|||
"""
|
||||
Get the full path of an image file given the path relative to the dataset folder and one of
|
||||
root_path or file_mapping.
|
||||
|
||||
:param file: Image filepath relative to the dataset folder
|
||||
:param root_path: The root path where all channel files for images are expected. This is ignored if
|
||||
file_mapping is given.
|
||||
file_mapping is given.
|
||||
|
||||
:param file_mapping: A mapping from a file name stem (without extension) to its full path.
|
||||
"""
|
||||
if file is None:
|
||||
|
|
|
@ -117,14 +117,15 @@ class DeepLearningFileSystemConfig(Parameterized):
|
|||
Creates a new object that holds output folder configurations. When running inside of AzureML, the output
|
||||
folders will be directly under the project root. If not running inside AzureML, a folder with a timestamp
|
||||
will be created for all outputs and logs.
|
||||
|
||||
:param project_root: The root folder that contains the code that submitted the present training run.
|
||||
When running inside the InnerEye repository, it is the git repo root. When consuming InnerEye as a package,
|
||||
this should be the root of the source code that calls the package.
|
||||
When running inside the InnerEye repository, it is the git repo root. When consuming InnerEye as a package,
|
||||
this should be the root of the source code that calls the package.
|
||||
:param is_offline_run: If true, this is a run outside AzureML. If False, it is inside AzureML.
|
||||
:param model_name: The name of the model that is trained. This is used to generate a run-specific output
|
||||
folder.
|
||||
folder.
|
||||
:param output_to: If provided, the output folders will be created as a subfolder of this argument. If not
|
||||
given, the output folders will be created inside of the project root.
|
||||
given, the output folders will be created inside of the project root.
|
||||
"""
|
||||
if not project_root.is_absolute():
|
||||
raise ValueError(f"The project root is required to be an absolute path, but got {project_root}")
|
||||
|
@ -165,6 +166,7 @@ class DeepLearningFileSystemConfig(Parameterized):
|
|||
"""
|
||||
Creates a new output folder configuration, where both outputs and logs go into the given subfolder inside
|
||||
the present outputs folder.
|
||||
|
||||
:param subfolder: The subfolder that should be created.
|
||||
:return:
|
||||
"""
|
||||
|
@ -281,6 +283,7 @@ class WorkflowParams(param.Parameterized):
|
|||
data_split: ModelExecutionMode) -> bool:
|
||||
"""
|
||||
Returns True if inference is required for this model_proc (single or ensemble) and data_split (Train/Val/Test).
|
||||
|
||||
:param model_proc: Whether we are testing an ensemble or single model.
|
||||
:param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
|
||||
:return: True if inference required.
|
||||
|
@ -442,6 +445,7 @@ class OutputParams(param.Parameterized):
|
|||
def set_output_to(self, output_to: PathOrString) -> None:
|
||||
"""
|
||||
Adjusts the file system settings in the present object such that all outputs are written to the given folder.
|
||||
|
||||
:param output_to: The absolute path to a folder that should contain the outputs.
|
||||
"""
|
||||
if isinstance(output_to, Path):
|
||||
|
@ -453,6 +457,7 @@ class OutputParams(param.Parameterized):
|
|||
"""
|
||||
Creates new file system settings (outputs folder, logs folder) based on the information stored in the
|
||||
present object. If any of the folders do not yet exist, they are created.
|
||||
|
||||
:param project_root: The root folder for the codebase that triggers the training run.
|
||||
"""
|
||||
self.file_system_config = DeepLearningFileSystemConfig.create(
|
||||
|
@ -783,6 +788,7 @@ class DeepLearningConfig(WorkflowParams,
|
|||
def dataset_data_frame(self, data_frame: Optional[DataFrame]) -> None:
|
||||
"""
|
||||
Sets the pandas data frame that the model uses.
|
||||
|
||||
:param data_frame: The data frame to set.
|
||||
"""
|
||||
self._dataset_data_frame = data_frame
|
||||
|
@ -844,12 +850,13 @@ class DeepLearningConfig(WorkflowParams,
|
|||
See https://pytorch.org/tutorials/beginner/saving_loading_models.html#warmstarting-model-using-parameters
|
||||
-from-a-different-model
|
||||
for an explanation on why strict=False is useful when loading parameters from other models.
|
||||
|
||||
:param path_to_checkpoint: Path to the checkpoint file.
|
||||
:return: Dictionary with model and optimizer state dicts. The dict should have at least the following keys:
|
||||
1. Key ModelAndInfo.MODEL_STATE_DICT_KEY and value set to the model state dict.
|
||||
2. Key ModelAndInfo.EPOCH_KEY and value set to the checkpoint epoch.
|
||||
Other (optional) entries corresponding to keys ModelAndInfo.OPTIMIZER_STATE_DICT_KEY and
|
||||
ModelAndInfo.MEAN_TEACHER_STATE_DICT_KEY are also supported.
|
||||
1. Key ModelAndInfo.MODEL_STATE_DICT_KEY and value set to the model state dict.
|
||||
2. Key ModelAndInfo.EPOCH_KEY and value set to the checkpoint epoch.
|
||||
Other (optional) entries corresponding to keys ModelAndInfo.OPTIMIZER_STATE_DICT_KEY and
|
||||
ModelAndInfo.MEAN_TEACHER_STATE_DICT_KEY are also supported.
|
||||
"""
|
||||
return load_checkpoint(path_to_checkpoint=path_to_checkpoint, use_gpu=self.use_gpu)
|
||||
|
||||
|
|
|
@ -275,6 +275,7 @@ class InnerEyeLightning(LightningModule):
|
|||
def validation_epoch_end(self, outputs: List[Any]) -> None:
|
||||
"""
|
||||
Resets the random number generator state to what it was before the current validation epoch started.
|
||||
|
||||
:param outputs: The list of outputs from the individual validation minibatches.
|
||||
"""
|
||||
# reset the random state for training, so that we get continue from where we were before the validation step.
|
||||
|
@ -315,12 +316,13 @@ class InnerEyeLightning(LightningModule):
|
|||
floating point, it is converted to a Tensor on the current device to enable synchronization.
|
||||
|
||||
:param sync_dist_override: If not None, use this value for the sync_dist argument to self.log. If None,
|
||||
set it automatically depending on the use of DDP.
|
||||
set it automatically depending on the use of DDP.
|
||||
|
||||
:param name: The name of the metric to log
|
||||
:param value: The value of the metric. This can be a tensor, floating point value, or a Metric class.
|
||||
:param is_training: If true, give the metric a "train/" prefix, otherwise a "val/" prefix.
|
||||
:param reduce_fx: The reduce function to use when synchronizing the tensors across GPUs. This must be
|
||||
a value recognized by sync_ddp: "sum", "mean"
|
||||
a value recognized by sync_ddp: "sum", "mean"
|
||||
"""
|
||||
metric_name = name if isinstance(name, str) else name.value
|
||||
prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX
|
||||
|
@ -334,10 +336,11 @@ class InnerEyeLightning(LightningModule):
|
|||
"""
|
||||
Stores a set of metrics (key/value pairs) to a file logger. That file logger is either one that only holds
|
||||
training or only holds validation metrics.
|
||||
|
||||
:param metrics: A dictionary with all the metrics to write, as key/value pairs.
|
||||
:param epoch: The epoch to which the metrics belong.
|
||||
:param is_training: If true, write the metrics to the logger for training metrics, if False, write to the logger
|
||||
for validation metrics.
|
||||
for validation metrics.
|
||||
"""
|
||||
file_logger = self.train_epoch_metrics_logger if is_training else self.val_epoch_metrics_logger
|
||||
store_epoch_metrics(metrics, epoch, file_logger=file_logger)
|
||||
|
@ -346,6 +349,7 @@ class InnerEyeLightning(LightningModule):
|
|||
"""
|
||||
This hook is called when loading a model from a checkpoint. It just prints out diagnostics about which epoch
|
||||
created the present checkpoint.
|
||||
|
||||
:param checkpoint: The checkpoint dictionary loaded from disk.
|
||||
"""
|
||||
keys = ['epoch', 'global_step']
|
||||
|
@ -371,10 +375,11 @@ class InnerEyeLightning(LightningModule):
|
|||
"""
|
||||
This is the shared method that handles the training (when `is_training==True`) and validation steps
|
||||
(when `is_training==False`)
|
||||
|
||||
:param sample: The minibatch of data that should be processed.
|
||||
:param batch_index: The index of the current minibatch.
|
||||
:param is_training: If true, this has been called from `training_step`, otherwise it has been called from
|
||||
`validation_step`.
|
||||
`validation_step`.
|
||||
"""
|
||||
raise NotImplementedError("This method must be overwritten in a derived class.")
|
||||
|
||||
|
@ -382,9 +387,10 @@ class InnerEyeLightning(LightningModule):
|
|||
"""
|
||||
Writes the given loss value to Lightning, labelled either "val/loss" or "train/loss".
|
||||
If this comes from a training step, then also log the learning rate.
|
||||
|
||||
:param loss: The loss value that should be logged.
|
||||
:param is_training: If True, the logged metric will be called "train/Loss". If False, the metric will
|
||||
be called "val/Loss"
|
||||
be called "val/Loss"
|
||||
"""
|
||||
assert isinstance(self.trainer, Trainer)
|
||||
self.log_on_epoch(MetricType.LOSS, loss, is_training)
|
||||
|
|
|
@ -31,16 +31,18 @@ class InnerEyeInference(abc.ABC):
|
|||
form of inference is slightly different from what PyTorch Lightning does in its `Trainer.test` method. In
|
||||
particular, this inference can be executed on any of the training, validation, or test set.
|
||||
|
||||
The inference code calls the methods in this order:
|
||||
The inference code calls the methods in this order::
|
||||
|
||||
model.on_inference_start()
|
||||
|
||||
for dataset_split in [Train, Val, Test]
|
||||
model.on_inference_epoch_start(dataset_split, is_ensemble_model=False)
|
||||
for batch_idx, item in enumerate(dataloader[dataset_split])):
|
||||
model_outputs = model.forward(item)
|
||||
model.inference_step(item, batch_idx, model_outputs)
|
||||
model.on_inference_epoch_end()
|
||||
model.on_inference_end()
|
||||
|
||||
model.on_inference_start()
|
||||
for dataset_split in [Train, Val, Test]
|
||||
model.on_inference_epoch_start(dataset_split, is_ensemble_model=False)
|
||||
for batch_idx, item in enumerate(dataloader[dataset_split])):
|
||||
model_outputs = model.forward(item)
|
||||
model.inference_step(item, batch_idx, model_outputs)
|
||||
model.on_inference_epoch_end()
|
||||
model.on_inference_end()
|
||||
"""
|
||||
|
||||
def on_inference_start(self) -> None:
|
||||
|
@ -55,9 +57,10 @@ class InnerEyeInference(abc.ABC):
|
|||
Runs initialization for inference, when starting inference on a new dataset split (train/val/test).
|
||||
Depending on the settings, this can be called anywhere between 0 (no inference at all) to 3 times (inference
|
||||
on all of train/val/test split).
|
||||
|
||||
:param dataset_split: Indicates whether the item comes from the training, validation or test set.
|
||||
:param is_ensemble_model: If False, the model_outputs come from an individual model. If True, the model
|
||||
outputs come from multiple models.
|
||||
outputs come from multiple models.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -65,6 +68,7 @@ class InnerEyeInference(abc.ABC):
|
|||
"""
|
||||
This hook is called when the model has finished making a prediction. It can write the results to a file,
|
||||
or compute metrics and store them.
|
||||
|
||||
:param batch: The batch of data for which the model made a prediction.
|
||||
:param model_output: The model outputs. This would usually be a torch.Tensor, but can be any datatype.
|
||||
"""
|
||||
|
@ -91,6 +95,7 @@ class InnerEyeInference(abc.ABC):
|
|||
"""
|
||||
Aggregates the outputs of multiple models when using an ensemble model. In the default implementation,
|
||||
this averages the tensors coming from all the models.
|
||||
|
||||
:param model_outputs: An iterator over the model outputs for all ensemble members.
|
||||
:return: The aggregate model outputs.
|
||||
"""
|
||||
|
@ -183,6 +188,7 @@ class LightningContainer(GenericConfig,
|
|||
Because the method deals with data loaders, not loaded data, we cannot check automatically that cross validation
|
||||
is handled correctly within the base class, i.e. if the cross validation split is not handled in the method then
|
||||
nothing will fail, but each child run will be identical since they will each be given the full dataset.
|
||||
|
||||
:return: A LightningDataModule
|
||||
"""
|
||||
return None # type: ignore
|
||||
|
@ -229,6 +235,7 @@ class LightningContainer(GenericConfig,
|
|||
This can be avoided by always using unique parameter names.
|
||||
Also note that saving a reference to `azure_config` and updating its attributes at any other
|
||||
point may lead to unexpected behaviour.
|
||||
|
||||
:param azure_config: The initialised AzureConfig whose parameters to override in-place.
|
||||
"""
|
||||
pass
|
||||
|
@ -297,6 +304,7 @@ class LightningContainer(GenericConfig,
|
|||
Because this adds a val/Loss metric it is important that when subclassing LightningContainer
|
||||
your implementation of LightningModule logs val/Loss. There is an example of this in
|
||||
HelloRegression's validation_step method.
|
||||
|
||||
:param run_config: The AzureML run configuration object that training for an individual model.
|
||||
:return: A hyperdrive configuration object.
|
||||
"""
|
||||
|
@ -315,6 +323,7 @@ class LightningContainer(GenericConfig,
|
|||
"""
|
||||
Returns the HyperDrive config for either parameter search or cross validation
|
||||
(if number_of_cross_validation_splits > 1).
|
||||
|
||||
:param run_config: AzureML estimator
|
||||
:return: HyperDriveConfigs
|
||||
"""
|
||||
|
|
|
@ -17,6 +17,7 @@ def load_from_lightning_checkpoint(config: ModelConfigBase, checkpoint_path: Pat
|
|||
"""
|
||||
Reads a PyTorch model from a checkpoint. First, a PyTorch Lightning model is created matching the InnerEye
|
||||
model configuration, its parameter tensors are then populated from the given checkpoint.
|
||||
|
||||
:param config: An InnerEye model configuration object
|
||||
:param checkpoint_path: The location of the checkpoint file.
|
||||
:return: A PyTorch Lightning model object.
|
||||
|
@ -37,6 +38,7 @@ def adjust_model_for_inference(config: ModelConfigBase, lightning_model: InnerEy
|
|||
Makes all necessary adjustments to use a given model for inference, possibly on multiple GPUs via
|
||||
model parallelization. The method also computes parameters like output patch size for segmentation model,
|
||||
and stores them in the model configuration.
|
||||
|
||||
:param config: The model configuration object. It may be modified in place.
|
||||
:param lightning_model: The trained model that should be adjusted.
|
||||
"""
|
||||
|
@ -65,6 +67,7 @@ def load_from_checkpoint_and_adjust_for_inference(config: ModelConfigBase, check
|
|||
"""
|
||||
Reads a PyTorch model from a checkpoint, and makes all necessary adjustments to use the model for inference,
|
||||
possibly on multiple GPUs.
|
||||
|
||||
:param config: An InnerEye model configuration object
|
||||
:param checkpoint_path: The location of the checkpoint file.
|
||||
:return: A PyTorch Lightning model object.
|
||||
|
|
|
@ -80,9 +80,10 @@ class StoringLogger(LightningLoggerBase):
|
|||
Reads the set of metrics for a given epoch, filters them to retain only those that have the given prefix,
|
||||
and returns the filtered ones. This is used to break a set
|
||||
of results down into those for training data (prefix "Train/") or validation data (prefix "Val/").
|
||||
|
||||
:param epoch: The epoch for which results should be read.
|
||||
:param prefix_filter: If empty string, return all metrics. If not empty, return only those metrics that
|
||||
have a name starting with `prefix`, and strip off the prefix.
|
||||
have a name starting with `prefix`, and strip off the prefix.
|
||||
:return: A metrics dictionary.
|
||||
"""
|
||||
epoch_results = self.results_per_epoch.get(epoch, None)
|
||||
|
@ -103,8 +104,9 @@ class StoringLogger(LightningLoggerBase):
|
|||
Converts the results stored in the present object into a two-level dictionary, mapping from epoch number to
|
||||
metric name to metric value. Only metrics where the name starts with the given prefix are retained, and the
|
||||
prefix is stripped off in the result.
|
||||
|
||||
:param prefix_filter: If empty string, return all metrics. If not empty, return only those metrics that
|
||||
have a name starting with `prefix`, and strip off the prefix.
|
||||
have a name starting with `prefix`, and strip off the prefix.
|
||||
:return: A dictionary mapping from epoch number to metric name to metric value.
|
||||
"""
|
||||
return {epoch: self.extract_by_prefix(epoch, prefix_filter) for epoch in self.epochs}
|
||||
|
@ -113,8 +115,10 @@ class StoringLogger(LightningLoggerBase):
|
|||
"""
|
||||
Gets a scalar metric out of either the list of training or the list of validation results. This returns
|
||||
the value that a specific metric attains in all of the epochs.
|
||||
|
||||
:param is_training: If True, read metrics that have a "train/" prefix, otherwise those that have a "val/"
|
||||
prefix.
|
||||
prefix.
|
||||
|
||||
:param metric_type: The metric to extract.
|
||||
:return: A list of floating point numbers, with one entry per entry in the the training or validation results.
|
||||
"""
|
||||
|
@ -132,6 +136,7 @@ class StoringLogger(LightningLoggerBase):
|
|||
"""
|
||||
Gets a scalar metric from the list of training results. This returns
|
||||
the value that a specific metric attains in all of the epochs.
|
||||
|
||||
:param metric_type: The metric to extract.
|
||||
:return: A list of floating point numbers, with one entry per entry in the the training results.
|
||||
"""
|
||||
|
@ -141,6 +146,7 @@ class StoringLogger(LightningLoggerBase):
|
|||
"""
|
||||
Gets a scalar metric from the list of validation results. This returns
|
||||
the value that a specific metric attains in all of the epochs.
|
||||
|
||||
:param metric_type: The metric to extract.
|
||||
:return: A list of floating point numbers, with one entry per entry in the the validation results.
|
||||
"""
|
||||
|
|
|
@ -20,6 +20,7 @@ def nanmean(values: torch.Tensor) -> torch.Tensor:
|
|||
"""
|
||||
Computes the average of all values in the tensor, skipping those entries that are NaN (not a number).
|
||||
If all values are NaN, the result is also NaN.
|
||||
|
||||
:param values: The values to average.
|
||||
:return: A scalar tensor containing the average.
|
||||
"""
|
||||
|
@ -164,7 +165,7 @@ class ScalarMetricsBase(Metric):
|
|||
difference between true positive rate and false positive rate is smallest. Then, computes
|
||||
the false positive rate, false negative rate and accuracy at this threshold (i.e. when the
|
||||
predicted probability is higher than the threshold the predicted label is 1 otherwise 0).
|
||||
:returns: Tuple(optimal_threshold, false positive rate, false negative rate, accuracy)
|
||||
:return: Tuple(optimal_threshold, false positive rate, false negative rate, accuracy)
|
||||
"""
|
||||
preds, targets = self._get_preds_and_targets()
|
||||
if torch.unique(targets).numel() == 1:
|
||||
|
@ -287,12 +288,14 @@ class MetricForMultipleStructures(torch.nn.Module):
|
|||
use_average_across_structures: bool = True) -> None:
|
||||
"""
|
||||
Creates a new MetricForMultipleStructures object.
|
||||
|
||||
:param ground_truth_ids: The list of anatomical structures that should be stored.
|
||||
:param metric_name: The name of the metric that should be stored. This is used in the names of the individual
|
||||
metrics.
|
||||
metrics.
|
||||
|
||||
:param is_training: If true, use "train/" as the prefix for all metric names, otherwise "val/"
|
||||
:param use_average_across_structures: If True, keep track of the average metric value across structures,
|
||||
while skipping NaNs. If false, only store the per-structure metric values.
|
||||
while skipping NaNs. If false, only store the per-structure metric values.
|
||||
"""
|
||||
super().__init__()
|
||||
prefix = (TRAIN_PREFIX if is_training else VALIDATION_PREFIX) + metric_name + "/"
|
||||
|
@ -307,6 +310,7 @@ class MetricForMultipleStructures(torch.nn.Module):
|
|||
"""
|
||||
Stores a vector of per-structure Dice scores in the present object. It updates the per-structure values,
|
||||
and the aggregate value across all structures.
|
||||
|
||||
:param values_per_structure: A row tensor that has as many entries as there are ground truth IDs.
|
||||
"""
|
||||
if values_per_structure.dim() != 1 or values_per_structure.numel() != self.count:
|
||||
|
|
|
@ -51,6 +51,7 @@ class SegmentationLightning(InnerEyeLightning):
|
|||
"""
|
||||
Runs a set of 3D crops through the segmentation model, and returns the result. This method is used
|
||||
at inference time.
|
||||
|
||||
:param patches: A tensor of size [batches, channels, Z, Y, X]
|
||||
"""
|
||||
return self.logits_to_posterior(self.model(patches))
|
||||
|
@ -67,8 +68,10 @@ class SegmentationLightning(InnerEyeLightning):
|
|||
is_training: bool) -> torch.Tensor:
|
||||
"""
|
||||
Runs training for a single minibatch of training or validation data, and computes all metrics.
|
||||
|
||||
:param is_training: If true, the method is called from `training_step`, otherwise it is called from
|
||||
`validation_step`.
|
||||
`validation_step`.
|
||||
|
||||
:param sample: The batched sample on which the model should be trained.
|
||||
:param batch_index: The index of the present batch (supplied only for diagnostics).
|
||||
"""
|
||||
|
@ -103,10 +106,11 @@ class SegmentationLightning(InnerEyeLightning):
|
|||
is_training: bool) -> None:
|
||||
"""
|
||||
Computes and stores all metrics coming out of a single training step.
|
||||
|
||||
:param cropped_sample: The batched image crops used for training or validation.
|
||||
:param segmentation: The segmentation that was produced by the model.
|
||||
:param is_training: If true, the method is called from `training_step`, otherwise it is called from
|
||||
`validation_step`.
|
||||
`validation_step`.
|
||||
"""
|
||||
# dice_per_crop_and_class has one row per crop, with background class removed
|
||||
# Dice NaN means that both ground truth and prediction are empty.
|
||||
|
@ -167,6 +171,7 @@ class SegmentationLightning(InnerEyeLightning):
|
|||
def get_subject_output_file_per_rank(rank: int) -> str:
|
||||
"""
|
||||
Gets the name of a file that will store the per-rank per-subject model outputs.
|
||||
|
||||
:param rank: The rank of the current model in distributed training.
|
||||
:return: A string like "rank7_metrics.csv"
|
||||
"""
|
||||
|
@ -228,11 +233,13 @@ class ScalarLightning(InnerEyeLightning):
|
|||
is_training: bool) -> torch.Tensor:
|
||||
"""
|
||||
Runs training for a single minibatch of training or validation data, and computes all metrics.
|
||||
|
||||
:param is_training: If true, the method is called from `training_step`, otherwise it is called from
|
||||
`validation_step`.
|
||||
`validation_step`.
|
||||
|
||||
:param sample: The batched sample on which the model should be trained.
|
||||
:param batch_index: The index of the present batch (supplied only for diagnostics).
|
||||
Runs a minibatch of training or validation data through the model.
|
||||
Runs a minibatch of training or validation data through the model.
|
||||
"""
|
||||
model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model, sample)
|
||||
labels = model_inputs_and_labels.labels
|
||||
|
@ -283,6 +290,7 @@ class ScalarLightning(InnerEyeLightning):
|
|||
"""
|
||||
For sequence models, transfer the nested lists of items to the given GPU device.
|
||||
For all other models, this relies on the superclass to move the batch of data to the GPU.
|
||||
|
||||
:param batch: A batch of data coming from the dataloader.
|
||||
:param device: The target CUDA device.
|
||||
:return: A modified batch of data, where all tensor now live on the given CUDA device.
|
||||
|
@ -294,6 +302,7 @@ def transfer_batch_to_device(batch: Any, device: torch.device) -> Any:
|
|||
"""
|
||||
For sequence models, transfer the nested lists of items to the given GPU device.
|
||||
For all other models, this relies on Lightning's default code to move the batch of data to the GPU.
|
||||
|
||||
:param batch: A batch of data coming from the dataloader.
|
||||
:param device: The target CUDA device.
|
||||
:return: A modified batch of data, where all tensor now live on the given CUDA device.
|
||||
|
@ -314,8 +323,10 @@ def create_lightning_model(config: ModelConfigBase, set_optimizer_and_scheduler:
|
|||
"""
|
||||
Creates a PyTorch Lightning model that matches the provided InnerEye model configuration object.
|
||||
The `optimizer` and `l_rate_scheduler` object of the Lightning model will also be populated.
|
||||
|
||||
:param set_optimizer_and_scheduler: If True (default), initialize the optimizer and LR scheduler of the model.
|
||||
If False, skip that step (this is only meant to be used for unit tests.)
|
||||
If False, skip that step (this is only meant to be used for unit tests.)
|
||||
|
||||
:param config: An InnerEye model configuration object
|
||||
:return: A PyTorch Lightning model object.
|
||||
"""
|
||||
|
|
|
@ -65,6 +65,7 @@ class InferenceMetricsForSegmentation(InferenceMetrics):
|
|||
def log_metrics(self, run_context: Run = None) -> None:
|
||||
"""
|
||||
Log metrics for each epoch to the provided runs logs, or the current run context if None provided
|
||||
|
||||
:param run_context: Run for which to log the metrics to, use the current run context if None provided
|
||||
:return:
|
||||
"""
|
||||
|
@ -79,6 +80,7 @@ def surface_distance(seg: sitk.Image, reference_segmentation: sitk.Image) -> flo
|
|||
"""
|
||||
Symmetric surface distances taking into account the image spacing
|
||||
https://github.com/InsightSoftwareConsortium/SimpleITK-Notebooks/blob/master/Python/34_Segmentation_Evaluation.ipynb
|
||||
|
||||
:param seg: mask 1
|
||||
:param reference_segmentation: mask 2
|
||||
:return: mean distance
|
||||
|
@ -118,6 +120,7 @@ def surface_distance(seg: sitk.Image, reference_segmentation: sitk.Image) -> flo
|
|||
def _add_zero_distances(num_segmented_surface_pixels: int, seg2ref_distance_map_arr: np.ndarray) -> List[float]:
|
||||
"""
|
||||
# Get all non-zero distances and then add zero distances if required.
|
||||
|
||||
:param num_segmented_surface_pixels:
|
||||
:param seg2ref_distance_map_arr:
|
||||
:return: list of distances, augmented with zeros.
|
||||
|
@ -137,6 +140,7 @@ def calculate_metrics_per_class(segmentation: np.ndarray,
|
|||
Returns a MetricsDict with metrics for each of the foreground
|
||||
structures. Metrics are NaN if both ground truth and prediction are all zero for a class.
|
||||
If first element of a ground truth image channel is NaN, the image is flagged as NaN and not use.
|
||||
|
||||
:param ground_truth_ids: The names of all foreground classes.
|
||||
:param segmentation: predictions multi-value array with dimensions: [Z x Y x X]
|
||||
:param ground_truth: ground truth binary array with dimensions: [C x Z x Y x X].
|
||||
|
@ -217,12 +221,13 @@ def compute_dice_across_patches(segmentation: torch.Tensor,
|
|||
allow_multiple_classes_for_each_pixel: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Computes the Dice scores for all classes across all patches in the arguments.
|
||||
|
||||
:param segmentation: Tensor containing class ids predicted by a model.
|
||||
:param ground_truth: One-hot encoded torch tensor containing ground-truth label ids.
|
||||
:param allow_multiple_classes_for_each_pixel: If set to False, ground-truth tensor has
|
||||
to contain only one foreground label for each pixel.
|
||||
:return A torch tensor of size (Patches, Classes) with the Dice scores. Dice scores are computed for
|
||||
all classes including the background class at index 0.
|
||||
to contain only one foreground label for each pixel.
|
||||
:return: A torch tensor of size (Patches, Classes) with the Dice scores. Dice scores are computed for
|
||||
all classes including the background class at index 0.
|
||||
"""
|
||||
check_size_matches(segmentation, ground_truth, 4, 5, [0, -3, -2, -1],
|
||||
arg1_name="segmentation", arg2_name="ground_truth")
|
||||
|
@ -255,6 +260,7 @@ def store_epoch_metrics(metrics: DictStrFloat,
|
|||
"""
|
||||
Writes all metrics (apart from ones that measure run time) into a CSV file,
|
||||
with an additional columns for epoch number.
|
||||
|
||||
:param file_logger: An instance of DataframeLogger, for logging results to csv.
|
||||
:param epoch: The epoch corresponding to the results.
|
||||
:param metrics: The metrics of the specified epoch, averaged along its batches.
|
||||
|
@ -291,12 +297,13 @@ def compute_scalar_metrics(metrics_dict: ScalarMetricsDict,
|
|||
of class 1. The label vector is expected to contain class indices 0 and 1 only.
|
||||
Metrics for each model output channel will be isolated, and a non-default hue for each model output channel is
|
||||
expected, and must exist in the provided metrics_dict. The Default hue is used for single model outputs.
|
||||
|
||||
:param metrics_dict: An object that holds all metrics. It will be updated in-place.
|
||||
:param subject_ids: Subject ids for the model output and labels.
|
||||
:param model_output: A tensor containing model outputs.
|
||||
:param labels: A tensor containing class labels.
|
||||
:param loss_type: The type of loss that the model uses. This is required to optionally convert 2-dim model output
|
||||
to probabilities.
|
||||
to probabilities.
|
||||
"""
|
||||
_model_output_channels = model_output.shape[1]
|
||||
model_output_hues = metrics_dict.get_hue_names(include_default=len(metrics_dict.hues_without_default) == 0)
|
||||
|
@ -346,6 +353,7 @@ def add_average_foreground_dice(metrics: MetricsDict) -> None:
|
|||
"""
|
||||
If the given metrics dictionary contains an entry for Dice score, and only one value for the Dice score per class,
|
||||
then add an average Dice score for all foreground classes to the metrics dictionary (modified in place).
|
||||
|
||||
:param metrics: The object that holds metrics. The average Dice score will be written back into this object.
|
||||
"""
|
||||
all_dice = []
|
||||
|
|
|
@ -35,9 +35,10 @@ def average_metric_values(values: List[float], skip_nan_when_averaging: bool) ->
|
|||
"""
|
||||
Returns the average (arithmetic mean) of the values provided. If skip_nan_when_averaging is True, the mean
|
||||
will be computed without any possible NaN values in the list.
|
||||
|
||||
:param values: The individual values that should be averaged.
|
||||
:param skip_nan_when_averaging: If True, compute mean with any NaN values. If False, any NaN value present
|
||||
in the argument will make the function return NaN.
|
||||
in the argument will make the function return NaN.
|
||||
:return: The average of the provided values. If the argument is an empty list, NaN will be returned.
|
||||
"""
|
||||
if skip_nan_when_averaging:
|
||||
|
@ -61,6 +62,7 @@ def get_column_name_for_logging(metric_name: Union[str, MetricType],
|
|||
"""
|
||||
Computes the column name that should be used when logging a metric to disk.
|
||||
Raises a value error when no column name has yet been defined.
|
||||
|
||||
:param metric_name: The name of the metric.
|
||||
:param hue_name: If provided will be used as a prefix hue_name/column_name
|
||||
"""
|
||||
|
@ -104,11 +106,13 @@ class Hue:
|
|||
labels: np.ndarray) -> None:
|
||||
"""
|
||||
Adds predictions and labels for later computing the area under the ROC curve.
|
||||
|
||||
:param subject_ids: Subject ids associated with the predictions and labels.
|
||||
:param predictions: A numpy array with model predictions, of size [N x C] for N samples in C classes, or size
|
||||
[N x 1] or size [N] for binary.
|
||||
[N x 1] or size [N] for binary.
|
||||
|
||||
:param labels: A numpy array with labels, of size [N x C] for N samples in C classes, or size
|
||||
[N x 1] or size [N] for binary.
|
||||
[N x 1] or size [N] for binary.
|
||||
"""
|
||||
if predictions.ndim == 1:
|
||||
predictions = np.expand_dims(predictions, axis=1)
|
||||
|
@ -155,6 +159,7 @@ class Hue:
|
|||
def _concat_if_needed(arrays: List[np.ndarray]) -> np.ndarray:
|
||||
"""
|
||||
Joins a list of arrays into a single array, taking empty lists into account correctly.
|
||||
|
||||
:param arrays: Array list to be concatenated.
|
||||
"""
|
||||
if arrays:
|
||||
|
@ -192,7 +197,8 @@ class MetricsDict:
|
|||
def __init__(self, hues: Optional[List[str]] = None, is_classification_metrics: bool = True) -> None:
|
||||
"""
|
||||
:param hues: Supported hues for this metrics dict, otherwise all records will belong to the
|
||||
default hue.
|
||||
default hue.
|
||||
|
||||
:param is_classification_metrics: If this is a classification metrics dict
|
||||
"""
|
||||
|
||||
|
@ -210,14 +216,16 @@ class MetricsDict:
|
|||
def subject_ids(self, hue: str = DEFAULT_HUE_KEY) -> List[str]:
|
||||
"""
|
||||
Return the subject ids that have metrics associated with them in this dictionary.
|
||||
|
||||
:param hue: If provided then subject ids belonging to this hue only will be returned.
|
||||
Otherwise subject ids for the default hue will be returned.
|
||||
Otherwise subject ids for the default hue will be returned.
|
||||
"""
|
||||
return self._get_hue(hue=hue).subject_ids
|
||||
|
||||
def get_hue_names(self, include_default: bool = True) -> List[str]:
|
||||
"""
|
||||
Returns all of the hues supported by this metrics dict
|
||||
|
||||
:param include_default: Include the default hue if True, otherwise exclude the default hue.
|
||||
"""
|
||||
_hue_names = list(self.hues.keys())
|
||||
|
@ -228,6 +236,7 @@ class MetricsDict:
|
|||
def delete_hue(self, hue: str) -> None:
|
||||
"""
|
||||
Removes all data stored for the given hue from the present object.
|
||||
|
||||
:param hue: The hue to remove.
|
||||
"""
|
||||
del self.hues[hue]
|
||||
|
@ -236,6 +245,7 @@ class MetricsDict:
|
|||
"""
|
||||
Gets the value stored for the given metric. The method assumes that there is a single value stored for the
|
||||
metric, and raises a ValueError if that is not the case.
|
||||
|
||||
:param metric_name: The name of the metric to retrieve.
|
||||
:param hue: The hue to retrieve the metric from.
|
||||
:return:
|
||||
|
@ -249,6 +259,7 @@ class MetricsDict:
|
|||
def has_prediction_entries(self, hue: str = DEFAULT_HUE_KEY) -> bool:
|
||||
"""
|
||||
Returns True if the present object stores any entries for computing the Area Under Roc Curve metric.
|
||||
|
||||
:param hue: will be used to check a particular hue otherwise default hue will be used.
|
||||
:return: True if entries exist. False otherwise.
|
||||
"""
|
||||
|
@ -257,8 +268,9 @@ class MetricsDict:
|
|||
def values(self, hue: str = DEFAULT_HUE_KEY) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns values held currently in the dict
|
||||
|
||||
:param hue: will be used to restrict values for the provided hue otherwise values in the default
|
||||
hue will be returned.
|
||||
hue will be returned.
|
||||
:return: Dictionary of values for this object.
|
||||
"""
|
||||
return self._get_hue(hue).values
|
||||
|
@ -267,6 +279,7 @@ class MetricsDict:
|
|||
"""
|
||||
Adds a diagnostic value to the present object. Multiple diagnostics can be stored per unique value of name,
|
||||
the values get concatenated.
|
||||
|
||||
:param name: The name of the diagnostic value to store.
|
||||
:param value: The value to store.
|
||||
"""
|
||||
|
@ -292,10 +305,12 @@ class MetricsDict:
|
|||
hue: str = DEFAULT_HUE_KEY) -> None:
|
||||
"""
|
||||
Adds values for a single metric to the present object, when the metric value is a scalar.
|
||||
|
||||
:param metric_name: The name of the metric to add. This can be a string or a value in the MetricType enum.
|
||||
:param metric_value: The values of the metric, as a float or integer.
|
||||
:param skip_nan_when_averaging: If True, averaging this metric will skip any NaN (not a number) values.
|
||||
If False, NaN will propagate through the mean computation.
|
||||
If False, NaN will propagate through the mean computation.
|
||||
|
||||
:param hue: The hue for which this record belongs to, default hue will be used if None provided.
|
||||
"""
|
||||
_metric_name = MetricsDict._metric_name(metric_name)
|
||||
|
@ -315,6 +330,7 @@ class MetricsDict:
|
|||
hue: str = DEFAULT_HUE_KEY) -> None:
|
||||
"""
|
||||
Deletes all values that are stored for a given metric from the present object.
|
||||
|
||||
:param metric_name: The name of the metric to add. This can be a string or a value in the MetricType enum.
|
||||
:param hue: The hue for which this record belongs to, default hue will be used if None provided.
|
||||
"""
|
||||
|
@ -327,11 +343,14 @@ class MetricsDict:
|
|||
hue: str = DEFAULT_HUE_KEY) -> None:
|
||||
"""
|
||||
Adds predictions and labels for later computing the area under the ROC curve.
|
||||
|
||||
:param subject_ids: Subject ids associated with the predictions and labels.
|
||||
:param predictions: A numpy array with model predictions, of size [N x C] for N samples in C classes, or size
|
||||
[N x 1] or size [N] for binary.
|
||||
[N x 1] or size [N] for binary.
|
||||
|
||||
:param labels: A numpy array with labels, of size [N x C] for N samples in C classes, or size
|
||||
[N x 1] or size [N] for binary.
|
||||
[N x 1] or size [N] for binary.
|
||||
|
||||
:param hue: The hue this prediction belongs to, default hue will be used if None provided.
|
||||
"""
|
||||
self._get_hue(hue).add_predictions(subject_ids=subject_ids,
|
||||
|
@ -341,6 +360,7 @@ class MetricsDict:
|
|||
def num_entries(self, hue: str = DEFAULT_HUE_KEY) -> Dict[str, int]:
|
||||
"""
|
||||
Gets the number of values that are stored for each individual metric.
|
||||
|
||||
:param hue: The hue to count entries for, otherwise all entries will be counted.
|
||||
:return: A dictionary mapping from metric name to number of values stored.
|
||||
"""
|
||||
|
@ -355,9 +375,10 @@ class MetricsDict:
|
|||
object.
|
||||
Computing the average will respect the skip_nan_when_averaging value that has been provided when adding
|
||||
the metric.
|
||||
|
||||
:param add_metrics_from_entries: average existing metrics in the dict.
|
||||
:param across_hues: If True then same metric types will be averaged regardless of hues, otherwise
|
||||
separate averages for each metric type for each hue will be computed, Default is True.
|
||||
separate averages for each metric type for each hue will be computed, Default is True.
|
||||
:return: A MetricsDict object with a single-item list for each of the metrics.
|
||||
"""
|
||||
|
||||
|
@ -434,8 +455,9 @@ class MetricsDict:
|
|||
difference between true positive rate and false positive rate is smallest. Then, computes
|
||||
the false positive rate, false negative rate and accuracy at this threshold (i.e. when the
|
||||
predicted probability is higher than the threshold the predicted label is 1 otherwise 0).
|
||||
|
||||
:param hue: The hue to restrict the values used for computation, otherwise all values will be used.
|
||||
:returns: Tuple(optimal_threshold, false positive rate, false negative rate, accuracy)
|
||||
:return: Tuple(optimal_threshold, false positive rate, false negative rate, accuracy)
|
||||
"""
|
||||
fpr, tpr, thresholds = roc_curve(self.get_labels(hue=hue), self.get_predictions(hue=hue))
|
||||
optimal_idx = MetricsDict.get_optimal_idx(fpr=fpr, tpr=tpr)
|
||||
|
@ -450,6 +472,7 @@ class MetricsDict:
|
|||
def get_roc_auc(self, hue: str = DEFAULT_HUE_KEY) -> float:
|
||||
"""
|
||||
Computes the Area Under the ROC curve, from the entries that were supplied in the add_roc_entries method.
|
||||
|
||||
:param hue: The hue to restrict the values used for computation, otherwise all values will be used.
|
||||
:return: The AUC score, or np.nan if no entries are available in the present object.
|
||||
"""
|
||||
|
@ -469,6 +492,7 @@ class MetricsDict:
|
|||
"""
|
||||
Computes the Area Under the Precision Recall Curve, from the entries that were supplied in the
|
||||
add_roc_entries method.
|
||||
|
||||
:param hue: The hue to restrict the values used for computation, otherwise all values will be used.
|
||||
:return: The PR AUC score, or np.nan if no entries are available in the present object.
|
||||
"""
|
||||
|
@ -488,6 +512,7 @@ class MetricsDict:
|
|||
"""
|
||||
Computes the binary cross entropy from the entries that were supplied in the
|
||||
add_roc_entries method.
|
||||
|
||||
:param hue: The hue to restrict the values used for computation, otherwise all values will be used.
|
||||
:return: The cross entropy score.
|
||||
"""
|
||||
|
@ -498,6 +523,7 @@ class MetricsDict:
|
|||
def get_mean_absolute_error(self, hue: str = DEFAULT_HUE_KEY) -> float:
|
||||
"""
|
||||
Get the mean absolute error.
|
||||
|
||||
:param hue: The hue to restrict the values used for computation, otherwise all values will be used.
|
||||
:return: Mean absolute error.
|
||||
"""
|
||||
|
@ -506,6 +532,7 @@ class MetricsDict:
|
|||
def get_mean_squared_error(self, hue: str = DEFAULT_HUE_KEY) -> float:
|
||||
"""
|
||||
Get the mean squared error.
|
||||
|
||||
:param hue: The hue to restrict the values used for computation, otherwise all values will be used.
|
||||
:return: Mean squared error
|
||||
"""
|
||||
|
@ -514,6 +541,7 @@ class MetricsDict:
|
|||
def get_r2_score(self, hue: str = DEFAULT_HUE_KEY) -> float:
|
||||
"""
|
||||
Get the R2 score.
|
||||
|
||||
:param hue: The hue to restrict the values used for computation, otherwise all values will be used.
|
||||
:return: R2 score
|
||||
"""
|
||||
|
@ -524,6 +552,7 @@ class MetricsDict:
|
|||
Returns an iterator that contains all (hue name, metric name, metric values) tuples that are stored in the
|
||||
present object. This method assumes that for each hue/metric combination there is exactly 1 value, and it
|
||||
throws an exception if that is more than 1 value.
|
||||
|
||||
:param hue: The hue to restrict the values, otherwise all values will be used if set to None.
|
||||
:return: An iterator with (hue name, metric name, metric values) pairs.
|
||||
"""
|
||||
|
@ -536,6 +565,7 @@ class MetricsDict:
|
|||
"""
|
||||
Returns an iterator that contains all (hue name, metric name, metric values) tuples that are stored in the
|
||||
present object.
|
||||
|
||||
:param hue: The hue to restrict the values, otherwise all values will be used if set to None.
|
||||
:param ensure_singleton_values_only: Ensure that each of the values return is a singleton.
|
||||
:return: An iterator with (hue name, metric name, metric values) pairs.
|
||||
|
@ -566,6 +596,7 @@ class MetricsDict:
|
|||
def get_predictions(self, hue: str = DEFAULT_HUE_KEY) -> np.ndarray:
|
||||
"""
|
||||
Return a concatenated copy of the roc predictions stored internally.
|
||||
|
||||
:param hue: The hue to restrict the values, otherwise all values will be used.
|
||||
:return: concatenated roc predictions as np array
|
||||
"""
|
||||
|
@ -574,6 +605,7 @@ class MetricsDict:
|
|||
def get_labels(self, hue: str = DEFAULT_HUE_KEY) -> np.ndarray:
|
||||
"""
|
||||
Return a concatenated copy of the roc labels stored internally.
|
||||
|
||||
:param hue: The hue to restrict the values, otherwise all values will be used.
|
||||
:return: roc labels as np array
|
||||
"""
|
||||
|
@ -583,6 +615,7 @@ class MetricsDict:
|
|||
-> List[PredictionEntry[float]]:
|
||||
"""
|
||||
Gets the per-subject labels and predictions that are stored in the present object.
|
||||
|
||||
:param hue: The hue to restrict the values, otherwise the default hue will be used.
|
||||
:return: List of per-subject labels and predictions
|
||||
"""
|
||||
|
@ -591,6 +624,7 @@ class MetricsDict:
|
|||
def to_string(self, tabulate: bool = True) -> str:
|
||||
"""
|
||||
Creates a multi-line human readable string from the given metrics.
|
||||
|
||||
:param tabulate: If True then create a pretty printable table string.
|
||||
:return: Formatted metrics string
|
||||
"""
|
||||
|
@ -623,6 +657,7 @@ class MetricsDict:
|
|||
"""
|
||||
Get the hue record for the provided key.
|
||||
Raises a KeyError if the provided hue key does not exist.
|
||||
|
||||
:param hue: The hue to retrieve record for
|
||||
"""
|
||||
if hue not in self.hues:
|
||||
|
@ -654,6 +689,7 @@ class ScalarMetricsDict(MetricsDict):
|
|||
cross_validation_split_index: int = DEFAULT_CROSS_VALIDATION_SPLIT_INDEX) -> None:
|
||||
"""
|
||||
Store metrics using the provided df_logger at subject level for classification models.
|
||||
|
||||
:param df_logger: A data frame logger to use to write the metrics to disk.
|
||||
:param mode: Model execution mode these metrics belong to.
|
||||
:param cross_validation_split_index: cross validation split index for the epoch if performing cross val
|
||||
|
@ -677,8 +713,9 @@ class ScalarMetricsDict(MetricsDict):
|
|||
"""
|
||||
Helper function to create BinaryClassificationMetricsDict grouped by ModelExecutionMode and epoch
|
||||
from a given dataframe. The following columns must exist in the provided data frame:
|
||||
>>> LoggingColumns.DataSplit
|
||||
>>> LoggingColumns.Epoch
|
||||
|
||||
* LoggingColumns.DataSplit
|
||||
* LoggingColumns.Epoch
|
||||
|
||||
:param df: DataFrame to use for creating the metrics dict.
|
||||
:param is_classification_metrics: If the current metrics are for classification or not.
|
||||
|
@ -720,6 +757,7 @@ class ScalarMetricsDict(MetricsDict):
|
|||
Given metrics dicts for execution modes and epochs, compute the aggregate metrics that are computed
|
||||
from the per-subject predictions. The metrics are written to the dataframe logger with the string labels
|
||||
(column names) taken from the `MetricType` enum.
|
||||
|
||||
:param metrics: Mapping between epoch and subject level metrics
|
||||
:param data_frame_logger: DataFrame logger to write to and flush
|
||||
:param log_info: If True then log results as an INFO string to the default logger also.
|
||||
|
@ -781,6 +819,7 @@ class SequenceMetricsDict(ScalarMetricsDict):
|
|||
"""
|
||||
Extracts a sequence target index from a metrics hue name. For example, from metrics hue "Seq_pos 07",
|
||||
it would return 7.
|
||||
|
||||
:param hue_name: hue name containing sequence target index
|
||||
"""
|
||||
if hue_name.startswith(SEQUENCE_POSITION_HUE_NAME_PREFIX):
|
||||
|
@ -807,6 +846,7 @@ class DataframeLogger:
|
|||
def flush(self, log_info: bool = False) -> None:
|
||||
"""
|
||||
Save the internal records to a csv file.
|
||||
|
||||
:param log_info: If true, write the final dataframe also to logging.info.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
|
|
@ -55,6 +55,7 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
|
|||
Returns a configuration for AzureML Hyperdrive that should be used when running hyperparameter
|
||||
tuning.
|
||||
This is an abstract method that each specific model should override.
|
||||
|
||||
:param run_config: The AzureML estimator object that runs model training.
|
||||
:return: A hyperdrive configuration object.
|
||||
"""
|
||||
|
@ -66,6 +67,7 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
|
|||
"""
|
||||
Computes the training, validation and test splits for the model, from a dataframe that contains
|
||||
the full dataset.
|
||||
|
||||
:param dataset_df: A dataframe that contains the full dataset that the model is using.
|
||||
:return: An instance of DatasetSplits with dataframes for training, validation and testing.
|
||||
"""
|
||||
|
@ -83,6 +85,7 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
|
|||
are False, the derived method *may* still create the corresponding datasets, but should not assume that
|
||||
the relevant splits (train/test/val) are non-empty. If either or both is True, they *must* create the
|
||||
corresponding datasets, and should be able to make the assumption.
|
||||
|
||||
:param for_training: whether to create the datasets required for training.
|
||||
:param for_inference: whether to create the datasets required for inference.
|
||||
"""
|
||||
|
@ -103,6 +106,8 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
|
|||
"""
|
||||
Returns a torch Dataset for running the model in inference mode, on the given split of the full dataset.
|
||||
The torch dataset must return data in the format required for running the model in inference mode.
|
||||
|
||||
:param mode: The mode of the model, either test, train or val.
|
||||
:return: A torch Dataset object.
|
||||
"""
|
||||
if self._datasets_for_inference is None:
|
||||
|
@ -114,7 +119,7 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
|
|||
"""
|
||||
Creates the torch DataLoaders that supply the training and the validation set during training only.
|
||||
:return: A dictionary, with keys ModelExecutionMode.TRAIN and ModelExecutionMode.VAL, and their respective
|
||||
data loaders.
|
||||
data loaders.
|
||||
"""
|
||||
logging.info("Starting to read and parse the datasets.")
|
||||
if self._datasets_for_training is None:
|
||||
|
@ -161,6 +166,7 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
|
|||
def get_cross_validation_hyperdrive_config(self, run_config: ScriptRunConfig) -> HyperDriveConfig:
|
||||
"""
|
||||
Returns a configuration for AzureML Hyperdrive that varies the cross validation split index.
|
||||
|
||||
:param run_config: The AzureML run configuration object that training for an individual model.
|
||||
:return: A hyperdrive configuration object.
|
||||
"""
|
||||
|
@ -176,9 +182,10 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
|
|||
"""
|
||||
When running cross validation, this method returns the dataset split that should be used for the
|
||||
currently executed cross validation split.
|
||||
|
||||
:param dataset_split: The full dataset, split into training, validation and test section.
|
||||
:return: The dataset split with training and validation sections shuffled according to the current
|
||||
cross validation index.
|
||||
cross validation index.
|
||||
"""
|
||||
splits = dataset_split.get_k_fold_cross_validation_splits(self.number_of_cross_validation_splits)
|
||||
return splits[self.cross_validation_split_index]
|
||||
|
@ -187,6 +194,7 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
|
|||
"""
|
||||
Returns the HyperDrive config for either parameter search or cross validation
|
||||
(if number_of_cross_validation_splits > 1).
|
||||
|
||||
:param run_config: AzureML estimator
|
||||
:return: HyperDriveConfigs
|
||||
"""
|
||||
|
@ -239,6 +247,7 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
|
|||
A hook to adjust the model configuration that is stored in the present object to match
|
||||
the torch model given in the argument. This hook is called after adjusting the model for
|
||||
mixed precision and parallel training.
|
||||
|
||||
:param model: The torch model.
|
||||
"""
|
||||
pass
|
||||
|
@ -267,7 +276,8 @@ class ModelTransformsPerExecutionMode:
|
|||
"""
|
||||
|
||||
:param train: the transformation(s) to apply to the training set.
|
||||
Should be a function that takes a sample as input and outputs sample.
|
||||
Should be a function that takes a sample as input and outputs sample.
|
||||
|
||||
:param val: the transformation(s) to apply to the validation set
|
||||
:param test: the transformation(s) to apply to the test set
|
||||
"""
|
||||
|
|
|
@ -51,6 +51,7 @@ def model_test(config: ModelConfigBase,
|
|||
Runs model inference on segmentation or classification models, using a given dataset (that could be training,
|
||||
test or validation set). The inference results and metrics will be stored and logged in a way that may
|
||||
differ for model categories (classification, segmentation).
|
||||
|
||||
:param config: The configuration of the model
|
||||
:param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
|
||||
:param checkpoint_paths: Checkpoint paths to initialize model.
|
||||
|
@ -82,6 +83,7 @@ def segmentation_model_test(config: SegmentationModelBase,
|
|||
"""
|
||||
The main testing loop for segmentation models.
|
||||
It loads the model and datasets, then proceeds to test the model for all requested checkpoints.
|
||||
|
||||
:param config: The arguments object which has a valid random seed attribute.
|
||||
:param execution_mode: Indicates which of the 3 sets (training, test, or validation) is being processed.
|
||||
:param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization.
|
||||
|
@ -121,6 +123,7 @@ def segmentation_model_test_epoch(config: SegmentationModelBase,
|
|||
The main testing loop for a given epoch. It loads the model and datasets, then proceeds to test the model.
|
||||
Returns a list with an entry for each image in the dataset. The entry is the average Dice score,
|
||||
where the average is taken across all non-background structures in the image.
|
||||
|
||||
:param checkpoint_paths: Checkpoint paths to run inference on.
|
||||
:param config: The arguments which specify all required information.
|
||||
:param execution_mode: Is the model evaluated on train, test, or validation set?
|
||||
|
@ -128,7 +131,7 @@ def segmentation_model_test_epoch(config: SegmentationModelBase,
|
|||
:param epoch_and_split: A string that should uniquely identify the epoch and the data split (train/val/test).
|
||||
:raises TypeError: If the arguments are of the wrong type.
|
||||
:raises ValueError: When there are issues loading the model.
|
||||
:return A list with the mean dice score (across all structures apart from background) for each image.
|
||||
:return: A list with the mean dice score (across all structures apart from background) for each image.
|
||||
"""
|
||||
ml_util.set_random_seed(config.get_effective_random_seed(), "Model testing")
|
||||
results_folder.mkdir(exist_ok=True)
|
||||
|
@ -204,11 +207,12 @@ def evaluate_model_predictions(process_id: int,
|
|||
Evaluates model segmentation predictions, dice scores and surface distances are computed.
|
||||
Generated contours are plotted and saved in results folder.
|
||||
The function is intended to be used in parallel for loop to process each image in parallel.
|
||||
|
||||
:param process_id: Identifier for the process calling the function
|
||||
:param config: Segmentation model config object
|
||||
:param dataset: Dataset object, it is used to load intensity image, labels, and patient metadata.
|
||||
:param results_folder: Path to results folder
|
||||
:returns [PatientMetadata, list[list]]: Patient metadata and list of computed metrics for each image.
|
||||
:return: [PatientMetadata, list[list]]: Patient metadata and list of computed metrics for each image.
|
||||
"""
|
||||
sample = dataset.get_samples_at_index(index=process_id)[0]
|
||||
logging.info(f"Evaluating predictions for patient {sample.patient_id}")
|
||||
|
@ -235,10 +239,12 @@ def populate_metrics_writer(
|
|||
config: SegmentationModelBase) -> Tuple[MetricsPerPatientWriter, List[FloatOrInt]]:
|
||||
"""
|
||||
Populate a MetricsPerPatientWriter with the metrics for each patient
|
||||
|
||||
:param model_prediction_evaluations: The list of PatientMetadata/MetricsDict tuples obtained
|
||||
from evaluate_model_predictions
|
||||
from evaluate_model_predictions
|
||||
|
||||
:param config: The SegmentationModelBase config from which we read the ground_truth_ids
|
||||
:returns: A new MetricsPerPatientWriter and a list of foreground DICE score averages
|
||||
:return: A new MetricsPerPatientWriter and a list of foreground DICE score averages
|
||||
"""
|
||||
average_dice: List[FloatOrInt] = []
|
||||
metrics_writer = MetricsPerPatientWriter()
|
||||
|
@ -263,6 +269,7 @@ def get_patient_results_folder(results_folder: Path, patient_id: int) -> Path:
|
|||
"""
|
||||
Gets a folder name that will contain all results for a given patient, like root/017 for patient 17.
|
||||
The folder name is constructed such that string sorting gives numeric sorting.
|
||||
|
||||
:param results_folder: The root folder in which the per-patient results should sit.
|
||||
:param patient_id: The numeric ID of the patient.
|
||||
:return: A path like "root/017"
|
||||
|
@ -276,8 +283,10 @@ def store_inference_results(inference_result: InferencePipeline.Result,
|
|||
image_header: ImageHeader) -> List[Path]:
|
||||
"""
|
||||
Store the segmentation, posteriors, and binary predictions into Nifti files.
|
||||
|
||||
:param inference_result: The inference result for a given patient_id and epoch. Posteriors must be in
|
||||
(Classes x Z x Y x X) shape, segmentation in (Z, Y, X)
|
||||
(Classes x Z x Y x X) shape, segmentation in (Z, Y, X)
|
||||
|
||||
:param config: The test configurations.
|
||||
:param results_folder: The folder where the prediction should be stored.
|
||||
:param image_header: The image header that was used in the input image.
|
||||
|
@ -286,6 +295,7 @@ def store_inference_results(inference_result: InferencePipeline.Result,
|
|||
def create_file_path(_results_folder: Path, _file_name: str) -> Path:
|
||||
"""
|
||||
Create filename with Nifti extension
|
||||
|
||||
:param _results_folder: The results folder
|
||||
:param _file_name: The name of the file
|
||||
:return: A full path to the results folder for the file
|
||||
|
@ -339,6 +349,7 @@ def store_run_information(results_folder: Path,
|
|||
image_channels: List[str]) -> None:
|
||||
"""
|
||||
Store dataset id and ground truth ids into files in the results folder.
|
||||
|
||||
:param image_channels: The names of the image channels that the model consumes.
|
||||
:param results_folder: The folder where the files should be stored.
|
||||
:param dataset_id: The dataset id
|
||||
|
@ -359,6 +370,7 @@ def create_inference_pipeline(config: ModelConfigBase,
|
|||
"""
|
||||
If multiple checkpoints are found in run_recovery then create EnsemblePipeline otherwise InferencePipeline.
|
||||
If no checkpoint files exist in the run recovery or current run checkpoint folder, None will be returned.
|
||||
|
||||
:param config: Model related configs.
|
||||
:param epoch: The epoch for which to create pipeline for.
|
||||
:param run_recovery: RunRecovery data if applicable
|
||||
|
@ -408,9 +420,11 @@ def classification_model_test(config: ScalarModelBase,
|
|||
"""
|
||||
The main testing loop for classification models. It runs a loop over all epochs for which testing should be done.
|
||||
It loads the model and datasets, then proceeds to test the model for all requested checkpoints.
|
||||
|
||||
:param config: The model configuration.
|
||||
:param data_split: The name of the folder to store the results inside each epoch folder in the outputs_dir,
|
||||
used mainly in model evaluation using different dataset splits.
|
||||
|
||||
:param checkpoint_paths: Checkpoint paths to initialize model
|
||||
:param model_proc: whether we are testing an ensemble or single model
|
||||
:return: InferenceMetricsForClassification object that contains metrics related for all of the checkpoint epochs.
|
||||
|
|
|
@ -38,6 +38,7 @@ def upload_output_file_as_temp(file_path: Path, outputs_folder: Path) -> None:
|
|||
"""
|
||||
Uploads a file to the AzureML run. It will get a name that is composed of a "temp/" prefix, plus the path
|
||||
of the file relative to the outputs folder that is used for training.
|
||||
|
||||
:param file_path: The path of the file to upload.
|
||||
:param outputs_folder: The root folder that contains all training outputs.
|
||||
"""
|
||||
|
@ -65,6 +66,7 @@ def create_lightning_trainer(container: LightningContainer,
|
|||
Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
|
||||
and loggers. That includes a diagnostic logger for use in unit tests, that is also returned as the second
|
||||
return value.
|
||||
|
||||
:param container: The container with model and data.
|
||||
:param resume_from_checkpoint: If provided, training resumes from this checkpoint point.
|
||||
:param num_nodes: The number of nodes to use in distributed training.
|
||||
|
@ -204,13 +206,14 @@ def model_train(checkpoint_path: Optional[Path],
|
|||
The main training loop. It creates the Pytorch model based on the configuration options passed in,
|
||||
creates a Pytorch Lightning trainer, and trains the model.
|
||||
If a checkpoint was specified, then it loads the checkpoint before resuming training.
|
||||
|
||||
:param checkpoint_path: Checkpoint path for model initialization
|
||||
:param num_nodes: The number of nodes to use in distributed training.
|
||||
:param container: A container object that holds the training data in PyTorch Lightning format
|
||||
and the model to train.
|
||||
and the model to train.
|
||||
:return: A tuple of [Trainer, StoringLogger]. Trainer is the Lightning Trainer object that was used for fitting
|
||||
the model. The StoringLogger object is returned when training an InnerEye built-in model, this is None when
|
||||
fitting other models.
|
||||
the model. The StoringLogger object is returned when training an InnerEye built-in model, this is None when
|
||||
fitting other models.
|
||||
"""
|
||||
lightning_model = container.model
|
||||
|
||||
|
@ -323,8 +326,9 @@ def model_train(checkpoint_path: Optional[Path],
|
|||
def aggregate_and_create_subject_metrics_file(outputs_folder: Path) -> None:
|
||||
"""
|
||||
This functions takes all the subject metrics file written by each GPU (one file per GPU) and aggregates them into
|
||||
one single metrics file. Results is saved in config.outputs_folder / mode.value / SUBJECT_METRICS_FILE_NAME.
|
||||
one single metrics file. Results is saved in ``config.outputs_folder / mode.value / SUBJECT_METRICS_FILE_NAME``.
|
||||
This is done for the metrics files for training and for validation data separately.
|
||||
|
||||
:param config: model config
|
||||
"""
|
||||
for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:
|
||||
|
|
|
@ -23,8 +23,9 @@ class CropSizeConstraints:
|
|||
"""
|
||||
:param multiple_of: Stores minimum size and other conditions that a training crop size must satisfy.
|
||||
:param minimum_size: Training crops must have a size that is a multiple of this value, along each dimension.
|
||||
For example, if set to (1, 16, 16), the crop size has to be a multiple of 16 along X and Y, and a
|
||||
For example, if set to (1, 16, 16), the crop size has to be a multiple of 16 along X and Y, and a
|
||||
multiple of 1 (i.e., any number) along the Z dimension.
|
||||
|
||||
:param num_dimensions: Training crops must have a size that is at least this value.
|
||||
"""
|
||||
self.multiple_of = multiple_of
|
||||
|
@ -56,6 +57,7 @@ class CropSizeConstraints:
|
|||
"""
|
||||
Checks if the given crop size is a valid crop size for the present model.
|
||||
If it is not valid, throw a ValueError.
|
||||
|
||||
:param crop_size: The crop size that should be checked.
|
||||
:param message_prefix: A string prefix for the error message if the crop size is found to be invalid.
|
||||
:return:
|
||||
|
@ -83,6 +85,7 @@ class CropSizeConstraints:
|
|||
(at test time). The new crop size will be the largest multiple of self.multiple_of that fits into the
|
||||
image_shape.
|
||||
The stride size will attempt to maintain the stride-to-crop ratio before adjustment.
|
||||
|
||||
:param image_shape: The shape of the image to process.
|
||||
:param crop_size: The present test crop size.
|
||||
:param stride_size: The present inference stride size.
|
||||
|
@ -121,10 +124,11 @@ class BaseSegmentationModel(DeviceAwareModule, ABC):
|
|||
):
|
||||
"""
|
||||
Creates a new instance of the base model class.
|
||||
|
||||
:param name: A human readable name of the model.
|
||||
:param input_channels: The number of image input channels.
|
||||
:param crop_size_constraints: The size constraints for the training crop size. If not provided,
|
||||
a minimum crop size of 1 is assumed.
|
||||
a minimum crop size of 1 is assumed.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_dimensions = 3
|
||||
|
@ -144,6 +148,7 @@ class BaseSegmentationModel(DeviceAwareModule, ABC):
|
|||
The argument is expected to be either a 2-tuple or a 3-tuple. A batch dimension (1)
|
||||
and the number of channels are added as the first dimensions. The result tuple has batch and channel dimension
|
||||
stripped off.
|
||||
|
||||
:param input_shape: A tuple (2D or 3D) representing incoming tensor shape.
|
||||
"""
|
||||
# Create a sample tensor for inference
|
||||
|
@ -166,6 +171,7 @@ class BaseSegmentationModel(DeviceAwareModule, ABC):
|
|||
"""
|
||||
Checks if the given crop size is a valid crop size for the present model.
|
||||
If it is not valid, throw a ValueError.
|
||||
|
||||
:param crop_size: The crop size that should be checked.
|
||||
:param message_prefix: A string prefix for the error message if the crop size is found to be invalid.
|
||||
"""
|
||||
|
@ -178,8 +184,9 @@ class BaseSegmentationModel(DeviceAwareModule, ABC):
|
|||
Stores a model summary, containing information about layers, memory consumption and runtime
|
||||
in the model.summary field.
|
||||
When called again with the same crop_size, the summary is not created again.
|
||||
|
||||
:param crop_size: The crop size for which the summary should be created. If not provided,
|
||||
the minimum allowed crop size is used.
|
||||
the minimum allowed crop size is used.
|
||||
:param log_summaries_to_files: whether to write the summary to a file
|
||||
"""
|
||||
if crop_size is None:
|
||||
|
|
|
@ -64,23 +64,29 @@ class ImageEncoder(DeviceAwareModule[ScalarItem, torch.Tensor]):
|
|||
"""
|
||||
Creates an image classifier that has UNet encoders sections for each image channel. The encoder output
|
||||
is fed through average pooling and an MLP.
|
||||
|
||||
:param encode_channels_jointly: If False, create a UNet encoder structure separately for each channel. If True,
|
||||
encode all channels jointly (convolution will run over all channels).
|
||||
encode all channels jointly (convolution will run over all channels).
|
||||
|
||||
:param num_encoder_blocks: Number of UNet encoder blocks.
|
||||
:param initial_feature_channels: Number of feature channels in the first UNet encoder.
|
||||
:param num_image_channels: Number of channels of the input. Input is expected to be of size
|
||||
B x num_image_channels x Z x Y x X, where B is the batch dimension.
|
||||
B x num_image_channels x Z x Y x X, where B is the batch dimension.
|
||||
|
||||
:param num_non_image_features: Number of non imaging features will be used in the model.
|
||||
:param kernel_size_per_encoding_block: The size of the kernels per encoding block, assumed to be the same
|
||||
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
|
||||
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
|
||||
performs convolutions only in X and Y.
|
||||
|
||||
:param stride_size_per_encoding_block: The stride size for the encoding block, assumed to be the same
|
||||
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
|
||||
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
|
||||
reduces spatial dimensions only in X and Y.
|
||||
|
||||
:param encoder_dimensionality_reduction_factor: how to reduce the dimensionality of the image features in the
|
||||
combined model to balance with non imaging features.
|
||||
combined model to balance with non imaging features.
|
||||
|
||||
:param scan_size: should be a tuple representing 3D tensor shape and if specified it's usedd in initializing
|
||||
gated pooling or z-adaptive. The first element should be representing the z-direction for classification images
|
||||
gated pooling or z-adaptive. The first element should be representing the z-direction for classification images
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_non_image_features = num_non_image_features
|
||||
|
@ -168,6 +174,7 @@ class ImageEncoder(DeviceAwareModule[ScalarItem, torch.Tensor]):
|
|||
def _get_aggregation_layer(self, aggregation_type: AggregationType, scan_size: Optional[TupleInt3]) -> Any:
|
||||
"""
|
||||
Returns the aggregation layer as specified by the config
|
||||
|
||||
:param aggregation_type: name of the aggregation
|
||||
:param scan_size: [Z, Y, X] size of the scans
|
||||
"""
|
||||
|
@ -191,6 +198,7 @@ class ImageEncoder(DeviceAwareModule[ScalarItem, torch.Tensor]):
|
|||
def get_input_tensors(self, item: ScalarItem) -> List[torch.Tensor]:
|
||||
"""
|
||||
Transforms a classification item into a torch.Tensor that the forward pass can consume
|
||||
|
||||
:param item: ClassificationItem
|
||||
:return: Tensor
|
||||
"""
|
||||
|
@ -290,25 +298,31 @@ class ImageEncoderWithMlp(ImageEncoder):
|
|||
Creates an image classifier that has UNet encoders sections for each image channel. The encoder output
|
||||
is fed through average pooling and an MLP. Extension of the ImageEncoder class using an MLP as classification
|
||||
layer.
|
||||
|
||||
:param encode_channels_jointly: If False, create a UNet encoder structure separately for each channel. If True,
|
||||
encode all channels jointly (convolution will run over all channels).
|
||||
encode all channels jointly (convolution will run over all channels).
|
||||
|
||||
:param num_encoder_blocks: Number of UNet encoder blocks.
|
||||
:param initial_feature_channels: Number of feature channels in the first UNet encoder.
|
||||
:param num_image_channels: Number of channels of the input. Input is expected to be of size
|
||||
B x num_image_channels x Z x Y x X, where B is the batch dimension.
|
||||
B x num_image_channels x Z x Y x X, where B is the batch dimension.
|
||||
|
||||
:param mlp_dropout: The amount of dropout that should be applied between the two layers of the classifier MLP.
|
||||
:param final_activation: Activation function to normalize the logits default is Identity.
|
||||
:param num_non_image_features: Number of non imaging features will be used in the model.
|
||||
:param kernel_size_per_encoding_block: The size of the kernels per encoding block, assumed to be the same
|
||||
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
|
||||
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
|
||||
performs convolutions only in X and Y.
|
||||
|
||||
:param stride_size_per_encoding_block: The stride size for the encoding block, assumed to be the same
|
||||
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
|
||||
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
|
||||
reduces spatial dimensions only in X and Y.
|
||||
|
||||
:param encoder_dimensionality_reduction_factor: how to reduce the dimensionality of the image features in the
|
||||
combined model to balance with non imaging features.
|
||||
combined model to balance with non imaging features.
|
||||
|
||||
:param scan_size: should be a tuple representing 3D tensor shape and if specified it's usedd in initializing
|
||||
gated pooling or z-adaptive. The first element should be representing the z-direction for classification images
|
||||
gated pooling or z-adaptive. The first element should be representing the z-direction for classification images
|
||||
"""
|
||||
super().__init__(imaging_feature_type=imaging_feature_type,
|
||||
encode_channels_jointly=encode_channels_jointly,
|
||||
|
@ -369,12 +383,13 @@ def create_mlp(input_num_feature_channels: int,
|
|||
hidden_layer_num_feature_channels: Optional[int] = None) -> MLP:
|
||||
"""
|
||||
Create an MLP with 1 hidden layer.
|
||||
|
||||
:param input_num_feature_channels: The number of input channels to the first MLP layer.
|
||||
:param dropout: The drop out factor that should be applied between the first and second MLP layer.
|
||||
:param final_output_channels: if provided, the final number of output channels.
|
||||
:param final_layer: if provided, the final (activation) layer to apply
|
||||
:param hidden_layer_num_feature_channels: if provided, will be used to create hidden layers, If None then
|
||||
input_num_feature_channels // 2 will be used to create the hidden layer.
|
||||
input_num_feature_channels // 2 will be used to create the hidden layer.
|
||||
:return:
|
||||
"""
|
||||
layers: List[torch.nn.Module] = []
|
||||
|
|
|
@ -160,11 +160,12 @@ class MultiSegmentationEncoder(DeviceAwareModule[ScalarItem, torch.Tensor]):
|
|||
) -> None:
|
||||
"""
|
||||
:param encode_channels_jointly: If False, create an encoder structure separately for each channel. If True,
|
||||
encode all channels jointly (convolution will run over all channels).
|
||||
encode all channels jointly (convolution will run over all channels).
|
||||
|
||||
:param num_image_channels: Number of channels of the input. Input is expected to be of size
|
||||
B x (num_image_channels * 10) x Z x Y x X, where B is the batch dimension.
|
||||
B x (num_image_channels * 10) x Z x Y x X, where B is the batch dimension.
|
||||
:param use_mixed_precision: If True, assume that training happens with mixed precision. Segmentations will
|
||||
be converted to float16 tensors right away. If False, segmentations will be converted to float32 tensors.
|
||||
be converted to float16 tensors right away. If False, segmentations will be converted to float32 tensors.
|
||||
"""
|
||||
super().__init__()
|
||||
self.encoder_input_channels = \
|
||||
|
@ -194,6 +195,7 @@ class MultiSegmentationEncoder(DeviceAwareModule[ScalarItem, torch.Tensor]):
|
|||
def get_input_tensors(self, item: ScalarItem) -> List[torch.Tensor]:
|
||||
"""
|
||||
Transforms a classification item into a torch.Tensor that the forward pass can consume
|
||||
|
||||
:param item: ClassificationItem
|
||||
:return: Tensor
|
||||
"""
|
||||
|
|
|
@ -27,9 +27,11 @@ class ComplexModel(BaseSegmentationModel):
|
|||
crop_size_constraints: Optional[CropSizeConstraints] = None):
|
||||
"""
|
||||
Creates a new instance of the class.
|
||||
|
||||
:param args: The full model configuration.
|
||||
:param full_channels_list: A vector of channel sizes. First entry is the number of image channels,
|
||||
then all feature channels, then the number of classes.
|
||||
then all feature channels, then the number of classes.
|
||||
|
||||
:param network_definition:
|
||||
:param crop_size_constraints: The size constraints for the training crop size.
|
||||
"""
|
||||
|
|
|
@ -59,6 +59,7 @@ class MLP(DeviceAwareModule[ScalarItem, torch.Tensor]):
|
|||
def get_input_tensors(self, item: ScalarItem) -> List[torch.Tensor]:
|
||||
"""
|
||||
Transforms a classification item into a torch.Tensor that the forward pass can consume
|
||||
|
||||
:param item: ClassificationItem
|
||||
:return: Tensor
|
||||
"""
|
||||
|
|
|
@ -26,15 +26,18 @@ class UNet2D(UNet3D):
|
|||
"""
|
||||
Initializes a 2D UNet model, where the input image is expected as a 3 dimensional tensor with a vanishing
|
||||
Z dimension.
|
||||
|
||||
:param input_image_channels: The number of image channels that the model should consume.
|
||||
:param initial_feature_channels: The number of feature maps used in the model in the first convolution layer.
|
||||
Subsequent layers will contain number of feature maps that are multiples of `initial_channels`
|
||||
Subsequent layers will contain number of feature maps that are multiples of `initial_channels`
|
||||
(2^(image_level) * initial_channels)
|
||||
|
||||
:param num_classes: Number of output classes
|
||||
:param num_downsampling_paths: Number of image levels used in Unet (in encoding and decoding paths)
|
||||
:param downsampling_dilation: An additional dilation that is used in the second convolution layer in each
|
||||
of the encoding blocks of the UNet. This can be used to increase the receptive field of the network. A good
|
||||
of the encoding blocks of the UNet. This can be used to increase the receptive field of the network. A good
|
||||
choice is (1, 2, 2), to increase the receptive field only in X and Y.
|
||||
|
||||
:param padding_mode: The type of padding that should be applied.
|
||||
"""
|
||||
super().__init__(input_image_channels=input_image_channels,
|
||||
|
|
|
@ -30,6 +30,7 @@ class UNet3D(BaseSegmentationModel):
|
|||
4) Support for more downsampling operations to capture larger image context and improve the performance.
|
||||
The network has `num_downsampling_paths` downsampling steps on the encoding side and same number upsampling steps
|
||||
on the decoding side.
|
||||
|
||||
:param num_downsampling_paths: Number of downsampling paths used in Unet model (default 4 image level are used)
|
||||
:param num_classes: Number of output segmentation classes
|
||||
:param kernel_size: Spatial support of convolution kernels used in Unet model
|
||||
|
@ -39,11 +40,15 @@ class UNet3D(BaseSegmentationModel):
|
|||
"""
|
||||
Implements upsampling block for UNet architecture. The operations carried out on the input tensor are
|
||||
1) Upsampling via strided convolutions 2) Concatenating the skip connection tensor 3) Two convolution layers
|
||||
|
||||
:param channels: A tuple containing the number of input and output channels
|
||||
:param upsample_kernel_size: Spatial support of upsampling kernels. If an integer is provided, the same value
|
||||
will be repeated for all three dimensions. For non-cubic kernels please pass a list or tuple with three elements
|
||||
will be repeated for all three dimensions. For non-cubic kernels please pass a list or tuple with three
|
||||
elements.
|
||||
|
||||
:param upsampling_stride: Upsamling factor used in deconvolutional layer. Similar to the `upsample_kernel_size`
|
||||
parameter, if an integer is passed, the same upsampling factor will be used for all three dimensions.
|
||||
parameter, if an integer is passed, the same upsampling factor will be used for all three dimensions.
|
||||
|
||||
:param activation: Linear/Non-linear activation function that is used after linear deconv/conv mappings.
|
||||
:param depth: The depth inside the UNet at which the layer operates. This is only for diagnostic purposes.
|
||||
"""
|
||||
|
@ -120,11 +125,14 @@ class UNet3D(BaseSegmentationModel):
|
|||
Implements a EncodeBlock for UNet.
|
||||
A EncodeBlock is two BasicLayers without dilation and with same padding.
|
||||
The first of those BasicLayer can use stride > 1, and hence will downsample.
|
||||
|
||||
:param channels: A list containing two elements representing the number of input and output channels
|
||||
:param kernel_size: Spatial support of convolution kernels. If an integer is provided, the same value will
|
||||
be repeated for all three dimensions. For non-cubic kernels please pass a tuple with three elements.
|
||||
be repeated for all three dimensions. For non-cubic kernels please pass a tuple with three elements.
|
||||
|
||||
:param downsampling_stride: Downsampling factor used in the first convolutional layer. If an integer is
|
||||
passed, the same downsampling factor will be used for all three dimensions.
|
||||
passed, the same downsampling factor will be used for all three dimensions.
|
||||
|
||||
:param dilation: Dilation of convolution kernels - If set to > 1, kernels capture content from wider range.
|
||||
:param activation: Linear/Non-linear activation function that is used after linear convolution mappings.
|
||||
:param use_residual: If set to True, block2 learns the residuals while preserving the output of block1
|
||||
|
@ -182,18 +190,22 @@ class UNet3D(BaseSegmentationModel):
|
|||
crop_size_constraints=crop_size_constraints)
|
||||
"""
|
||||
Modified 3D-Unet Class
|
||||
|
||||
:param input_image_channels: Number of image channels (scans) that are fed into the model.
|
||||
:param initial_feature_channels: Number of feature-maps used in the model - Subsequent layers will contain
|
||||
number
|
||||
number
|
||||
of featuremaps that is multiples of `initial_feature_channels` (e.g. 2^(image_level) * initial_feature_channels)
|
||||
|
||||
:param num_classes: Number of output classes
|
||||
:param kernel_size: Spatial support of conv kernels in each spatial axis.
|
||||
:param num_downsampling_paths: Number of image levels used in Unet (in encoding and decoding paths)
|
||||
:param downsampling_factor: Spatial downsampling factor for each tensor axis (depth, width, height). This will
|
||||
be used as the stride for the first convolution layer in each encoder block.
|
||||
be used as the stride for the first convolution layer in each encoder block.
|
||||
|
||||
:param downsampling_dilation: An additional dilation that is used in the second convolution layer in each
|
||||
of the encoding blocks of the UNet. This can be used to increase the receptive field of the network. A good
|
||||
of the encoding blocks of the UNet. This can be used to increase the receptive field of the network. A good
|
||||
choice is (1, 2, 2), to increase the receptive field only in X and Y.
|
||||
|
||||
:param crop_size: The size of the crop that should be used for training.
|
||||
"""
|
||||
|
||||
|
|
|
@ -14,15 +14,18 @@ class BasicLayer(torch.nn.Module):
|
|||
"""
|
||||
A Basic Layer applies a 3D convolution and BatchNorm with the given channels, kernel_size, and dilation.
|
||||
The output of BatchNorm layer is passed through an activation function and its output is returned.
|
||||
|
||||
:param channels: Number of input and output channels.
|
||||
:param kernel_size: Spatial support of convolution kernels
|
||||
:param stride: Kernel stride lenght for convolution op
|
||||
:param padding: Feature map padding after convolution op {"constant/zero", "no_padding"}. When it is set to
|
||||
"no_padding", no padding is applied. For "constant", feature-map tensor size is kept the same at the output by
|
||||
"no_padding", no padding is applied. For "constant", feature-map tensor size is kept the same at the output by
|
||||
padding with zeros.
|
||||
|
||||
:param dilation: Kernel dilation used in convolution layer
|
||||
:param use_bias: If set to True, a bias parameter will be added to the layer. Default is set to False as
|
||||
batch normalisation layer has an affine parameter which are used applied after the bias term is added.
|
||||
batch normalisation layer has an affine parameter which are used applied after the bias term is added.
|
||||
|
||||
:param activation: Activation layer (e.g. nonlinearity) to be used after the convolution and batch norm operations.
|
||||
"""
|
||||
|
||||
|
|
|
@ -21,13 +21,16 @@ class CrossEntropyLoss(SupervisedLearningCriterion):
|
|||
super().__init__(smoothing_eps)
|
||||
"""
|
||||
Multi-class cross entropy loss.
|
||||
|
||||
:param class_weight_power: if 1.0, weights the cross-entropy term for each class equally.
|
||||
Class weights are inversely proportional to the number
|
||||
of pixels belonging to each class, raised to class_weight_power
|
||||
|
||||
:param focal_loss_gamma: Gamma term used in focal loss to weight negative log-likelihood term:
|
||||
https://arxiv.org/pdf/1708.02002.pdf equation(4-5).
|
||||
When gamma equals to zero, it is equivalent to standard
|
||||
CE with no class balancing. (Gamma >= 0.0)
|
||||
|
||||
:param ignore_index: Specifies a target value that is ignored and does not contribute
|
||||
to the input gradient
|
||||
"""
|
||||
|
@ -60,6 +63,7 @@ class CrossEntropyLoss(SupervisedLearningCriterion):
|
|||
def get_focal_loss_pixel_weights(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Computes weights for each pixel/sample inversely proportional to the posterior likelihood.
|
||||
|
||||
:param logits: Logits tensor.
|
||||
:param target: Target label tensor in one-hot encoding.
|
||||
"""
|
||||
|
@ -103,6 +107,7 @@ class CrossEntropyLoss(SupervisedLearningCriterion):
|
|||
Wrapper for multi-class cross entropy function implemented in PyTorch.
|
||||
The implementation supports tensors with arbitrary spatial dimension.
|
||||
Input logits are normalised internally in `F.cross_entropy` function.
|
||||
|
||||
:param output: Class logits (unnormalised), e.g. in 3D : BxCxWxHxD or in 1D BxC
|
||||
:param target: Target labels encoded in one-hot representation, e.g. in 3D BxCxWxHxD or in 1D BxC
|
||||
"""
|
||||
|
|
|
@ -14,6 +14,7 @@ class MixtureLoss(SupervisedLearningCriterion):
|
|||
def __init__(self, components: List[Tuple[float, SupervisedLearningCriterion]]):
|
||||
"""
|
||||
Loss function defined as a weighted mixture (interpolation) of other loss functions.
|
||||
|
||||
:param components: a non-empty list of weights and loss function instances.
|
||||
"""
|
||||
super().__init__()
|
||||
|
@ -25,6 +26,7 @@ class MixtureLoss(SupervisedLearningCriterion):
|
|||
"""
|
||||
Wrapper for mixture loss function implemented in PyTorch. Arguments should be suitable for the
|
||||
component loss functions, typically:
|
||||
|
||||
:param output: Class logits (unnormalised), e.g. in 3D : BxCxWxHxD or in 1D BxC
|
||||
:param target: Target labels encoded in one-hot representation, e.g. in 3D BxCxWxHxD or in 1D BxC
|
||||
"""
|
||||
|
|
|
@ -24,9 +24,10 @@ class SoftDiceLoss(SupervisedLearningCriterion):
|
|||
"""
|
||||
:param eps: A small constant to smooth Sorensen-Dice Loss function. Additionally, it avoids division by zero.
|
||||
:param apply_softmax: If true, the input to the loss function will be first fed through a Softmax operation.
|
||||
If false, the input to the loss function will be used as is.
|
||||
If false, the input to the loss function will be used as is.
|
||||
|
||||
:param class_weight_power: power to raise 1/C to, where C is the number of voxels in each class. Should be
|
||||
non-negative to help increase accuracy on small structures.
|
||||
non-negative to help increase accuracy on small structures.
|
||||
"""
|
||||
super().__init__()
|
||||
#: Small value to avoid division by zero errors.
|
||||
|
|
|
@ -14,6 +14,7 @@ def move_to_device(input_tensors: List[torch.Tensor],
|
|||
non_blocking: bool = False) -> Iterable[torch.Tensor]:
|
||||
"""
|
||||
Updates the memory location of tensors stored in a list.
|
||||
|
||||
:param input_tensors: List of torch tensors
|
||||
:param target_device: Target device (e.g. cuda:0, cuda:1, etc). If the device is None, the tensors are not moved.
|
||||
:param non_blocking: bool
|
||||
|
@ -40,6 +41,7 @@ def group_layers_with_balanced_memory(inputs: List[torch.nn.Module],
|
|||
summary: Optional[OrderedDict]) -> Generator:
|
||||
"""
|
||||
Groups layers in the model in a balanced way as such each group has similar size of memory requirement
|
||||
|
||||
:param inputs: List of input torch modules.
|
||||
:param num_groups: Number of groups to be produced.
|
||||
:param summary: Model summary of the input layers which is used to retrieve memory requirements.
|
||||
|
@ -121,6 +123,7 @@ def partition_layers(layers: List[torch.nn.Module],
|
|||
target_devices: List[torch.device]) -> None:
|
||||
"""
|
||||
Splits the models into multiple chunks and assigns each sub-model to a particular GPU
|
||||
|
||||
:param layers: The layers to partition
|
||||
:param summary: Model architecture summary to use for partitioning
|
||||
:param target_devices: The devices to partition layers into
|
||||
|
|
|
@ -45,8 +45,9 @@ def create_parser(yaml_file_path: Path) -> ParserResult:
|
|||
Create a parser for all runner arguments, even though we are only using a subset of the arguments.
|
||||
This way, we can get secrets handling in a consistent way.
|
||||
In particular, this will create arguments for
|
||||
--local_dataset
|
||||
--azure_dataset_id
|
||||
|
||||
* ``--local_dataset``
|
||||
* ``--azure_dataset_id``
|
||||
"""
|
||||
parser = create_runner_parser(SegmentationModelBase)
|
||||
NormalizeAndVisualizeConfig.add_args(parser)
|
||||
|
@ -67,9 +68,11 @@ def get_configs(default_model_config: SegmentationModelBase,
|
|||
def main(yaml_file_path: Path) -> None:
|
||||
"""
|
||||
Invoke either by
|
||||
* specifying a model, '--model Lung'
|
||||
* or specifying dataset and normalization parameters separately: --azure_dataset_id=foo --norm_method=None
|
||||
In addition, the arguments '--image_channel' and '--gt_channel' must be specified (see below).
|
||||
|
||||
* specifying a model, ``--model Lung``
|
||||
* or specifying dataset and normalization parameters separately: ``--azure_dataset_id=foo --norm_method=None``
|
||||
|
||||
In addition, the arguments ``--image_channel`` and ``--gt_channel`` must be specified.
|
||||
"""
|
||||
config, runner_config, args = get_configs(SegmentationModelBase(should_validate=False), yaml_file_path)
|
||||
dataset_config = DatasetConfig(name=config.azure_dataset_id,
|
||||
|
|
|
@ -242,6 +242,7 @@ def robust_mean_std(data: np.ndarray) -> Tuple[float, float, float, float]:
|
|||
Computes robust estimates of mean and standard deviation in the given array.
|
||||
The median is the robust estimate for the mean, the standard deviation is computed from the
|
||||
inter-quartile ranges.
|
||||
|
||||
:param data: The data for which mean and std should be computed.
|
||||
:return: A 4-tuple with values (median, robust_std, minimum data value, maximum data value)
|
||||
"""
|
||||
|
@ -272,9 +273,11 @@ def mri_window(image_in: np.ndarray,
|
|||
around the mean of the remaining values and with a range controlled by the standard deviation and the sharpen
|
||||
input parameter. The larger sharpen is, the wider the range. The resulting values are the normalised to the given
|
||||
output_range, with values below and above the range being set the the boundary values.
|
||||
|
||||
:param image_in: The image to normalize.
|
||||
:param mask: Consider only pixel values of the input image for which the mask is non-zero. If None the whole
|
||||
image is considered.
|
||||
image is considered.
|
||||
|
||||
:param output_range: The desired value range of the result image.
|
||||
:param sharpen: number of standard deviation either side of mean to include in the window
|
||||
:param tail: Default 1, allow window range to include more of tail of distribution.
|
||||
|
|
|
@ -49,12 +49,14 @@ class EnsemblePipeline(FullImageInferencePipelineBase):
|
|||
aggregation_type: EnsembleAggregationType) -> InferencePipeline.Result:
|
||||
"""
|
||||
Helper method to aggregate results from multiple inference pipelines, based on the aggregation type provided.
|
||||
|
||||
:param results: inference pipeline results to aggregate. This may be a Generator to prevent multiple large
|
||||
posterior arrays being held at the same time. The first element of the sequence is modified in place to
|
||||
posterior arrays being held at the same time. The first element of the sequence is modified in place to
|
||||
minimize memory use.
|
||||
|
||||
:param aggregation_type: aggregation function to use to combine the results.
|
||||
:return: InferenceResult: contains a Segmentation for each of the classes and their posterior
|
||||
probabilities.
|
||||
probabilities.
|
||||
"""
|
||||
if aggregation_type != EnsembleAggregationType.Average:
|
||||
raise NotImplementedError(f"Ensembling is not implemented for aggregation type: {aggregation_type}")
|
||||
|
@ -78,12 +80,13 @@ class EnsemblePipeline(FullImageInferencePipelineBase):
|
|||
"""
|
||||
Performs a single inference pass for each model in the ensemble, and aggregates the results
|
||||
based on the provided aggregation type.
|
||||
|
||||
:param image_channels: The input image channels to perform inference on in format: Channels x Z x Y x X.
|
||||
:param voxel_spacing_mm: Voxel spacing to use for each dimension in (Z x Y x X) order
|
||||
:param mask: A binary image used to ignore results outside it in format: Z x Y x X.
|
||||
:param patient_id: The identifier of the patient this image belongs to.
|
||||
:return InferenceResult: that contains Segmentation for each of the classes and their posterior
|
||||
probabilities.
|
||||
:return: InferenceResult: that contains Segmentation for each of the classes and their posterior
|
||||
probabilities.
|
||||
"""
|
||||
logging.info(f"Ensembling inference pipelines ({self._get_pipeline_ids()}) "
|
||||
f"predictions for patient: {patient_id}, "
|
||||
|
|
|
@ -56,6 +56,7 @@ class FullImageInferencePipelineBase(InferencePipelineBase):
|
|||
"""
|
||||
Perform connected component analysis to update segmentation with largest
|
||||
connected component based on the configurations
|
||||
|
||||
:param results: inference results to post-process
|
||||
:return: post-processed version of results
|
||||
"""
|
||||
|
@ -199,12 +200,14 @@ class InferencePipeline(FullImageInferencePipelineBase):
|
|||
Creates an instance of the inference pipeline for a given epoch from a stored checkpoint.
|
||||
After loading, the model parameters are checked for NaN and Infinity values.
|
||||
If there is no checkpoint file for the given epoch, return None.
|
||||
|
||||
:param path_to_checkpoint: The path to the checkpoint that we want to load
|
||||
model_config.checkpoint_folder
|
||||
model_config.checkpoint_folder
|
||||
:param model_config: Model related configurations.
|
||||
:param pipeline_id: Numeric identifier for the pipeline (useful for logging when ensembling)
|
||||
:return InferencePipeline: an instantiated inference pipeline instance, or None if there was no checkpoint
|
||||
file for this epoch.
|
||||
|
||||
:return: InferencePipeline: an instantiated inference pipeline instance, or None if there was no checkpoint
|
||||
file for this epoch.
|
||||
"""
|
||||
if not path_to_checkpoint.is_file():
|
||||
# not raising a value error here: This is used to create individual pipelines for ensembles,
|
||||
|
@ -244,7 +247,7 @@ class InferencePipeline(FullImageInferencePipelineBase):
|
|||
:param voxel_spacing_mm: Voxel spacing to use for each dimension in (Z x Y x X) order
|
||||
:param mask: A binary image used to ignore results outside it in format: Z x Y x X.
|
||||
:param patient_id: The identifier of the patient this image belongs to (defaults to 0 if None provided).
|
||||
:return InferenceResult: that contains Segmentation for each of the classes and their posterior probabilities.
|
||||
:return: InferenceResult: that contains Segmentation for each of the classes and their posterior probabilities.
|
||||
"""
|
||||
torch.cuda.empty_cache()
|
||||
if image_channels is None:
|
||||
|
|
|
@ -140,6 +140,7 @@ class ScalarEnsemblePipeline(ScalarInferencePipelineBase):
|
|||
config: ScalarModelBase) -> ScalarEnsemblePipeline:
|
||||
"""
|
||||
Creates an ensemble pipeline from a list of checkpoints.
|
||||
|
||||
:param paths_to_checkpoint: List of paths to the checkpoints to be recovered.
|
||||
:param config: Model configuration information.
|
||||
:return:
|
||||
|
@ -179,8 +180,9 @@ class ScalarEnsemblePipeline(ScalarInferencePipelineBase):
|
|||
def aggregate_model_outputs(self, model_outputs: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Aggregates the forward pass results from the individual models in the ensemble.
|
||||
|
||||
:param model_outputs: List of model outputs for every model in the ensemble.
|
||||
(Number of ensembles) x (batch_size) x 1
|
||||
(Number of ensembles) x (batch_size) x 1
|
||||
"""
|
||||
# aggregate model outputs
|
||||
if self.aggregation_type == EnsembleAggregationType.Average:
|
||||
|
|
|
@ -28,6 +28,7 @@ def is_val_dice(name: str) -> bool:
|
|||
"""
|
||||
Returns true if the given metric name is a Dice score on the validation set,
|
||||
for a class that is not the background class.
|
||||
|
||||
:param name:
|
||||
:return:
|
||||
"""
|
||||
|
@ -37,6 +38,7 @@ def is_val_dice(name: str) -> bool:
|
|||
def get_val_dice_names(metric_names: Iterable[str]) -> List[str]:
|
||||
"""
|
||||
Returns a list of those metric names from the argument that fulfill the is_val_dice predicate.
|
||||
|
||||
:param metric_names:
|
||||
:return:
|
||||
"""
|
||||
|
@ -47,6 +49,7 @@ def plot_loss_per_epoch(metrics: Dict[str, Any], metric_name: str, label: Option
|
|||
"""
|
||||
Adds a plot of loss (y-axis) versus epoch (x-axis) to the current plot, if the metric
|
||||
is present in the metrics dictionary.
|
||||
|
||||
:param metrics: A dictionary of metrics.
|
||||
:param metric_name: The name of the single metric to plot.
|
||||
:param label: The label for the series that will be plotted.
|
||||
|
@ -64,9 +67,10 @@ def plot_loss_per_epoch(metrics: Dict[str, Any], metric_name: str, label: Option
|
|||
def plot_val_dice_per_epoch(metrics: Dict[str, Any]) -> int:
|
||||
"""
|
||||
Creates a plot of all validation Dice scores per epoch, for all classes apart from background.
|
||||
|
||||
:param metrics:
|
||||
:return: The number of series that were plotted in the graph. Can return 0 if the metrics dictionary
|
||||
does not contain any validation Dice score.
|
||||
does not contain any validation Dice score.
|
||||
"""
|
||||
plt.clf()
|
||||
series_count = 0
|
||||
|
@ -83,6 +87,7 @@ def plot_val_dice_per_epoch(metrics: Dict[str, Any]) -> int:
|
|||
def add_legend(series_count: int) -> None:
|
||||
"""
|
||||
Adds a legend to the present plot, with the column layout depending on the number of series.
|
||||
|
||||
:param series_count:
|
||||
:return:
|
||||
"""
|
||||
|
@ -93,6 +98,7 @@ def add_legend(series_count: int) -> None:
|
|||
def resize_and_save(width_inch: int, height_inch: int, filename: PathOrString, dpi: int = 150) -> None:
|
||||
"""
|
||||
Resizes the present figure to the given (width, height) in inches, and saves it to the given filename.
|
||||
|
||||
:param width_inch: The width of the figure in inches.
|
||||
:param height_inch: The height of the figure in inches.
|
||||
:param filename: The filename to save to.
|
||||
|
@ -113,13 +119,17 @@ def plot_image_and_label_contour(image: np.ndarray,
|
|||
"""
|
||||
Creates a plot that shows the given 2D image in greyscale, and overlays a contour that shows
|
||||
where the 'labels' array has value 1.
|
||||
|
||||
:param image: A 2D image
|
||||
:param labels: A binary 2D image, or a list of binary 2D images. A contour will be plotted for each of those
|
||||
binary images.
|
||||
binary images.
|
||||
|
||||
:param contour_arguments: A dictionary of keyword arguments that will be passed directly into matplotlib's
|
||||
contour function. Can also be a list of dictionaries, with one dict per entry in the 'labels' argument.
|
||||
contour function. Can also be a list of dictionaries, with one dict per entry in the 'labels' argument.
|
||||
|
||||
:param image_range: If provided, the image will be plotted using the given range for the color limits.
|
||||
If None, the minimum and maximum image values will be mapped to the endpoints of the color map.
|
||||
If None, the minimum and maximum image values will be mapped to the endpoints of the color map.
|
||||
|
||||
:param plot_file_name: The file name that should be used to save the plot.
|
||||
"""
|
||||
if image.ndim != 2:
|
||||
|
@ -183,6 +193,7 @@ def plot_before_after_statistics(image_before: np.ndarray,
|
|||
that were obtained before and after a transformation of pixel values.
|
||||
The plot contains histograms, box plots, and visualizations of a single XY slice at z_slice.
|
||||
If a mask argument is provided, only the image pixel values inside of the mask will be plotted.
|
||||
|
||||
:param image_before: The first image for which to plot statistics.
|
||||
:param image_after: The second image for which to plot statistics.
|
||||
:param mask: Indicators with 1 for foreground, 0 for background. If None, plot statistics for all image pixels.
|
||||
|
@ -237,10 +248,13 @@ def plot_normalization_result(loaded_images: Sample,
|
|||
The first plot contains pixel value histograms before and after photometric normalization.
|
||||
The second plot contains the normalized image, overlayed with contours for the foreground pixels,
|
||||
at the slice where the foreground has most pixels.
|
||||
|
||||
:param loaded_images: An instance of Sample with the image and the labels. The first channel of the image will
|
||||
be plotted.
|
||||
be plotted.
|
||||
|
||||
:param image_range: The image value range that will be mapped to the color map. If None, the full image range
|
||||
will be mapped to the colormap.
|
||||
will be mapped to the colormap.
|
||||
|
||||
:param normalizer: The photometric normalization that should be applied.
|
||||
:param result_folder: The folder into which the resulting PNG files should be written.
|
||||
:param result_prefix: The prefix for all output filenames.
|
||||
|
@ -283,13 +297,15 @@ def plot_contours_for_all_classes(sample: Sample,
|
|||
"""
|
||||
Creates a plot with the image, the ground truth, and the predicted segmentation overlaid. One plot is created
|
||||
for each class, each plotting the Z slice where the ground truth has most pixels.
|
||||
|
||||
:param sample: The image sample, with the photonormalized image and the ground truth labels.
|
||||
:param segmentation: The predicted segmentation: multi-value, size Z x Y x X.
|
||||
:param foreground_class_names: The names of all classes, excluding the background class.
|
||||
:param result_folder: The folder into which the resulting plot PNG files should be written.
|
||||
:param result_prefix: A string prefix that will be used for all plots.
|
||||
:param image_range: The minimum and maximum image values that will be mapped to the color map ranges.
|
||||
If None, use the actual min and max values.
|
||||
If None, use the actual min and max values.
|
||||
|
||||
:param channel_index: The index of the image channel that should be plotted.
|
||||
:return: The paths to all generated PNG files.
|
||||
"""
|
||||
|
@ -337,6 +353,7 @@ def segmentation_and_groundtruth_plot(prediction: np.ndarray, ground_truth: np.n
|
|||
"""
|
||||
Plot predicted and the ground truth segmentations. Always plots the middle slice (to match surface distance
|
||||
plots), which can sometimes lead to an empty plot.
|
||||
|
||||
:param prediction: 3D volume (X x Y x Z) of predicted segmentation
|
||||
:param ground_truth: 3D volume (X x Y x Z) of ground truth segmentation
|
||||
:param subject_id: ID of subject for annotating plot
|
||||
|
@ -389,6 +406,7 @@ def surface_distance_ground_truth_plot(ct: np.ndarray, ground_truth: np.ndarray,
|
|||
annotator: str = None) -> None:
|
||||
"""
|
||||
Plot surface distances where prediction > 0, with ground truth contour
|
||||
|
||||
:param ct: CT scan
|
||||
:param ground_truth: Ground truth segmentation
|
||||
:param sds_full: Surface distances (full= where prediction > 0)
|
||||
|
@ -471,10 +489,12 @@ def scan_with_transparent_overlay(scan: np.ndarray,
|
|||
information in the range [0, 1]. High values of the `overlay` are shown as opaque red, low values as transparent
|
||||
red.
|
||||
Plots are created in the current axis.
|
||||
|
||||
:param scan: A 3-dimensional image in (Z, Y, X) ordering
|
||||
:param overlay: A 3-dimensional image in (Z, Y, X) ordering, with values between 0 and 1.
|
||||
:param dimension: The array dimension along with the plot should be created. dimension=0 will generate
|
||||
an axial slice.
|
||||
an axial slice.
|
||||
|
||||
:param position: The index in the chosen dimension where the plot should be created.
|
||||
:param spacing: The tuple of voxel spacings, in (Z, Y, X) order.
|
||||
"""
|
||||
|
|
|
@ -73,6 +73,7 @@ def get_dataframe_with_exact_label_matches(metrics_df: pd.DataFrame,
|
|||
The dataframe must have at least the following columns (defined in the LoggingColumns enum):
|
||||
LoggingColumns.Hue, LoggingColumns.Patient, LoggingColumns.Label, LoggingColumns.ModelOutput.
|
||||
Any other columns will be ignored.
|
||||
|
||||
:param prediction_target_set_to_match: The set of prediction targets to which each sample is compared
|
||||
:param all_prediction_targets: The entire set of prediction targets on which the model is trained
|
||||
:param thresholds_per_prediction_target: Thresholds per prediction target to decide if model has predicted True or
|
||||
|
@ -139,9 +140,9 @@ def print_metrics_for_thresholded_output_for_all_prediction_targets(csv_to_set_o
|
|||
prediction targets that exist in the dataset (i.e. for every subset of classes that occur in the dataset).
|
||||
|
||||
:param csv_to_set_optimal_threshold: Csv written during inference time for the val set. This is used to determine
|
||||
the optimal threshold for classification.
|
||||
the optimal threshold for classification.
|
||||
:param csv_to_compute_metrics: Csv written during inference time for the test set. Metrics are calculated for
|
||||
this csv.
|
||||
this csv.
|
||||
:param config: Model config
|
||||
"""
|
||||
|
||||
|
|
|
@ -69,9 +69,11 @@ def read_csv_and_filter_prediction_target(csv: Path, prediction_target: str,
|
|||
|
||||
:param csv: Path to the metrics CSV file. Must contain at least the following columns (defined in the LoggingColumns
|
||||
enum): LoggingColumns.Patient, LoggingColumns.Hue.
|
||||
|
||||
:param prediction_target: Target ("hue") by which to filter.
|
||||
:param crossval_split_index: If specified, filter rows only for the respective run (requires
|
||||
LoggingColumns.CrossValidationSplitIndex).
|
||||
|
||||
:param data_split: If specified, filter rows by Train/Val/Test (requires LoggingColumns.DataSplit).
|
||||
:param epoch: If specified, filter rows for given epoch (default: last epoch only; requires LoggingColumns.Epoch).
|
||||
:return: Filtered dataframe.
|
||||
|
@ -122,9 +124,11 @@ def get_labels_and_predictions(csv: Path, prediction_target: str,
|
|||
|
||||
:param csv: Path to the metrics CSV file. Must contain at least the following columns (defined in the LoggingColumns
|
||||
enum): LoggingColumns.Patient, LoggingColumns.Hue.
|
||||
|
||||
:param prediction_target: Target ("hue") by which to filter.
|
||||
:param crossval_split_index: If specified, filter rows only for the respective run (requires
|
||||
LoggingColumns.CrossValidationSplitIndex).
|
||||
|
||||
:param data_split: If specified, filter rows by Train/Val/Test (requires LoggingColumns.DataSplit).
|
||||
:param epoch: If specified, filter rows for given epoch (default: last epoch only; requires LoggingColumns.Epoch).
|
||||
:return: Filtered labels and model outputs.
|
||||
|
@ -150,6 +154,7 @@ def get_labels_and_predictions_from_dataframe(df: pd.DataFrame) -> LabelsAndPred
|
|||
def format_pr_or_roc_axes(plot_type: str, ax: Axes) -> None:
|
||||
"""
|
||||
Format PR or ROC plot with appropriate title, axis labels, limits, and grid.
|
||||
|
||||
:param plot_type: Either 'pr' or 'roc'.
|
||||
:param ax: Axes object to format.
|
||||
"""
|
||||
|
@ -171,9 +176,11 @@ def plot_pr_and_roc_curves(labels_and_model_outputs: LabelsAndPredictions, axs:
|
|||
plot_kwargs: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""
|
||||
Given labels and model outputs, plot the ROC and PR curves.
|
||||
|
||||
:param labels_and_model_outputs:
|
||||
:param axs: Pair of axes objects onto which to plot the ROC and PR curves, respectively. New axes are created by
|
||||
default.
|
||||
default.
|
||||
|
||||
:param plot_kwargs: Plotting options to be passed to both `ax.plot(...)` calls.
|
||||
"""
|
||||
if axs is None:
|
||||
|
@ -202,15 +209,18 @@ def plot_scores_and_summary(all_labels_and_model_outputs: Sequence[LabelsAndPred
|
|||
|
||||
Each plotted curve is interpolated onto a common horizontal grid, and the median and CI are computed vertically
|
||||
at each horizontal location.
|
||||
|
||||
:param all_labels_and_model_outputs: Collection of ground-truth labels and model predictions (e.g. for various
|
||||
cross-validation runs).
|
||||
cross-validation runs).
|
||||
|
||||
:param scoring_fn: A scoring function mapping a `LabelsAndPredictions` object to X and Y coordinates for plotting.
|
||||
:param interval_width: A value in [0, 1] representing what fraction of the data should be contained in
|
||||
the shaded area. The edges of the interval are `median +/- interval_width/2`.
|
||||
the shaded area. The edges of the interval are `median +/- interval_width/2`.
|
||||
|
||||
:param ax: Axes object onto which to plot (default: use current axes).
|
||||
:return: A tuple of `(line_handles, summary_handle)` to use in setting a legend for the plot: `line_handles` is a
|
||||
list corresponding to the curves for each `LabelsAndPredictions`, and `summary_handle` references the median line
|
||||
and shaded CI area.
|
||||
list corresponding to the curves for each `LabelsAndPredictions`, and `summary_handle` references the median line
|
||||
and shaded CI area.
|
||||
"""
|
||||
if ax is None:
|
||||
ax = plt.gca()
|
||||
|
@ -236,10 +246,12 @@ def plot_pr_and_roc_curves_crossval(all_labels_and_model_outputs: Sequence[Label
|
|||
"""
|
||||
Given a list of LabelsAndPredictions objects, plot the corresponding ROC and PR curves, along with median line and
|
||||
shaded 80% confidence interval (computed over TPRs and precisions for each fixed FPR and recall value).
|
||||
|
||||
:param all_labels_and_model_outputs: Collection of ground-truth labels and model predictions (e.g. for various
|
||||
cross-validation runs).
|
||||
cross-validation runs).
|
||||
|
||||
:param axs: Pair of axes objects onto which to plot the ROC and PR curves, respectively. New axes are created by
|
||||
default.
|
||||
default.
|
||||
"""
|
||||
if axs is None:
|
||||
_, axs = plt.subplots(1, 2)
|
||||
|
@ -276,11 +288,12 @@ def plot_pr_and_roc_curves_from_csv(metrics_csv: Path, config: ScalarModelBase,
|
|||
"""
|
||||
Given the CSV written during inference time and the model config, plot the ROC and PR curves for all prediction
|
||||
targets.
|
||||
|
||||
:param metrics_csv: Path to the metrics CSV file.
|
||||
:param config: Model config.
|
||||
:param data_split: Whether to filter the CSV file for Train, Val, or Test results (default: no filtering).
|
||||
:param is_crossval_report: If True, assumes CSV contains results for multiple cross-validation runs and plots the
|
||||
curves with median and confidence intervals. Otherwise, plots curves for a single run.
|
||||
curves with median and confidence intervals. Otherwise, plots curves for a single run.
|
||||
"""
|
||||
for prediction_target in config.target_names:
|
||||
print_header(f"Class: {prediction_target}", level=3)
|
||||
|
@ -300,8 +313,10 @@ def get_metric(predictions_to_set_optimal_threshold: LabelsAndPredictions,
|
|||
optimal_threshold: Optional[float] = None) -> float:
|
||||
"""
|
||||
Given LabelsAndPredictions objects for the validation and test sets, return the specified metric.
|
||||
|
||||
:param predictions_to_set_optimal_threshold: This set of ground truth labels and model predictions is used to
|
||||
determine the optimal threshold for classification.
|
||||
determine the optimal threshold for classification.
|
||||
|
||||
:param predictions_to_compute_metrics: The set of labels and model outputs to calculate metrics for.
|
||||
:param metric: The name of the metric to calculate.
|
||||
:param optimal_threshold: If provided, use this threshold instead of calculating an optimal threshold.
|
||||
|
@ -351,6 +366,7 @@ def get_all_metrics(predictions_to_set_optimal_threshold: LabelsAndPredictions,
|
|||
is_thresholded: bool = False) -> Dict[str, float]:
|
||||
"""
|
||||
Given LabelsAndPredictions objects for the validation and test sets, compute some metrics.
|
||||
|
||||
:param predictions_to_set_optimal_threshold: This is used to determine the optimal threshold for classification.
|
||||
:param predictions_to_compute_metrics: Metrics are calculated for this set.
|
||||
:param is_thresholded: Whether the model outputs are binary (they have been thresholded at some point)
|
||||
|
@ -379,6 +395,7 @@ def print_metrics(predictions_to_set_optimal_threshold: LabelsAndPredictions,
|
|||
is_thresholded: bool = False) -> None:
|
||||
"""
|
||||
Given LabelsAndPredictions objects for the validation and test sets, print out some metrics.
|
||||
|
||||
:param predictions_to_set_optimal_threshold: This is used to determine the optimal threshold for classification.
|
||||
:param predictions_to_compute_metrics: Metrics are calculated for this set.
|
||||
:param is_thresholded: Whether the model outputs are binary (they have been thresholded at some point)
|
||||
|
@ -402,16 +419,21 @@ def get_metrics_table_for_prediction_target(csv_to_set_optimal_threshold: Path,
|
|||
|
||||
:param csv_to_set_optimal_threshold: CSV written during inference time for the val set. This is used to determine
|
||||
the optimal threshold for classification.
|
||||
|
||||
:param csv_to_compute_metrics: CSV written during inference time for the test set. Metrics are calculated for
|
||||
this CSV.
|
||||
|
||||
:param config: Model config
|
||||
:param prediction_target: The prediction target for which to compute metrics.
|
||||
:param data_split_to_set_optimal_threshold: Whether to filter the validation CSV file for Train, Val, or Test
|
||||
results (default: no filtering).
|
||||
|
||||
:param data_split_to_compute_metrics: Whether to filter the test CSV file for Train, Val, or Test results
|
||||
(default: no filtering).
|
||||
|
||||
:param is_thresholded: Whether the model outputs are binary (they have been thresholded at some point)
|
||||
or are floating point numbers.
|
||||
|
||||
:param is_crossval_report: If True, assumes CSVs contain results for multiple cross-validation runs and formats the
|
||||
metrics along with means and standard deviations. Otherwise, collect metrics for a single run.
|
||||
:return: Tuple of rows and header, where each row and the header are lists of strings of same length (2 if
|
||||
|
@ -468,15 +490,20 @@ def print_metrics_for_all_prediction_targets(csv_to_set_optimal_threshold: Path,
|
|||
|
||||
:param csv_to_set_optimal_threshold: CSV written during inference time for the val set. This is used to determine
|
||||
the optimal threshold for classification.
|
||||
|
||||
:param csv_to_compute_metrics: CSV written during inference time for the test set. Metrics are calculated for
|
||||
this CSV.
|
||||
|
||||
:param config: Model config
|
||||
:param data_split_to_set_optimal_threshold: Whether to filter the validation CSV file for Train, Val, or Test
|
||||
results (default: no filtering).
|
||||
|
||||
:param data_split_to_compute_metrics: Whether to filter the test CSV file for Train, Val, or Test results
|
||||
(default: no filtering).
|
||||
|
||||
:param is_thresholded: Whether the model outputs are binary (they have been thresholded at some point)
|
||||
or are floating point numbers.
|
||||
|
||||
:param is_crossval_report: If True, assumes CSVs contain results for multiple cross-validation runs and prints the
|
||||
metrics along with means and standard deviations. Otherwise, prints metrics for a single run.
|
||||
"""
|
||||
|
@ -563,11 +590,14 @@ def print_k_best_and_worst_performing(val_metrics_csv: Path,
|
|||
"""
|
||||
Print the top "k" best predictions (i.e. correct classifications where the model was the most certain) and the
|
||||
top "k" worst predictions (i.e. misclassifications where the model was the most confident).
|
||||
|
||||
:param val_metrics_csv: Path to one of the metrics csvs written during inference. This set of metrics will be
|
||||
used to determine the thresholds for predicting labels on the test set. The best and worst
|
||||
performing subjects will not be printed out for this csv.
|
||||
|
||||
:param test_metrics_csv: Path to one of the metrics csvs written during inference. This is the csv for which
|
||||
best and worst performing subjects will be printed out.
|
||||
|
||||
:param k: Number of subjects of each category to print out.
|
||||
:param prediction_target: The class label to filter on
|
||||
"""
|
||||
|
@ -605,6 +635,7 @@ def get_image_filepath_from_subject_id(subject_id: str,
|
|||
config: ScalarModelBase) -> List[Path]:
|
||||
"""
|
||||
Return the filepaths for images associated with a subject. If the subject is not found, raises a ValueError.
|
||||
|
||||
:param subject_id: Subject to retrive image for
|
||||
:param dataset: scalar dataset object
|
||||
:param config: model config
|
||||
|
@ -623,6 +654,7 @@ def get_image_labels_from_subject_id(subject_id: str,
|
|||
config: ScalarModelBase) -> List[str]:
|
||||
"""
|
||||
Return the ground truth labels associated with a subject. If the subject is not found, raises a ValueError.
|
||||
|
||||
:param subject_id: Subject to retrive image for
|
||||
:param dataset: scalar dataset object
|
||||
:param config: model config
|
||||
|
@ -657,6 +689,7 @@ def get_image_outputs_from_subject_id(subject_id: str,
|
|||
def plot_image_from_filepath(filepath: Path, im_width: int) -> bool:
|
||||
"""
|
||||
Plots a 2D image given the filepath. Returns false if the image could not be plotted (for example, if it was 3D).
|
||||
|
||||
:param filepath: Path to image
|
||||
:param im_width: Display width for image
|
||||
:return: True if image was plotted, False otherwise
|
||||
|
@ -693,6 +726,7 @@ def plot_image_for_subject(subject_id: str,
|
|||
metrics_df: Optional[pd.DataFrame] = None) -> None:
|
||||
"""
|
||||
Given a subject ID, plots the corresponding image.
|
||||
|
||||
:param subject_id: Subject to plot image for
|
||||
:param dataset: scalar dataset object
|
||||
:param im_width: Display width for image
|
||||
|
@ -741,11 +775,14 @@ def plot_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Pa
|
|||
"""
|
||||
Plot images for the top "k" best predictions (i.e. correct classifications where the model was the most certain)
|
||||
and the top "k" worst predictions (i.e. misclassifications where the model was the most confident).
|
||||
|
||||
:param val_metrics_csv: Path to one of the metrics csvs written during inference. This set of metrics will be
|
||||
used to determine the thresholds for predicting labels on the test set. The best and worst
|
||||
performing subjects will not be printed out for this csv.
|
||||
|
||||
:param test_metrics_csv: Path to one of the metrics csvs written during inference. This is the csv for which
|
||||
best and worst performing subjects will be printed out.
|
||||
|
||||
:param k: Number of subjects of each category to print out.
|
||||
:param prediction_target: The class label to filter on
|
||||
:param config: scalar model config object
|
||||
|
|
|
@ -26,6 +26,7 @@ reports_folder = "reports"
|
|||
def get_ipynb_report_name(report_type: str) -> str:
|
||||
"""
|
||||
Constructs the name of the report (as an ipython notebook).
|
||||
|
||||
:param report_type: suffix describing the report, added to the filename
|
||||
:return:
|
||||
"""
|
||||
|
@ -35,6 +36,7 @@ def get_ipynb_report_name(report_type: str) -> str:
|
|||
def get_html_report_name(report_type: str) -> str:
|
||||
"""
|
||||
Constructs the name of the report (as an html file).
|
||||
|
||||
:param report_type: suffix describing the report, added to the filename
|
||||
:return:
|
||||
"""
|
||||
|
@ -49,6 +51,7 @@ def print_header(message: str, level: int = 2) -> None:
|
|||
"""
|
||||
Displays a message, and afterwards repeat it as Markdown with the given indentation level (level=1 is the
|
||||
outermost, `# Foo`.
|
||||
|
||||
:param message: The header string to display.
|
||||
:param level: The Markdown indentation level. level=1 for top level, level=3 for `### Foo`
|
||||
"""
|
||||
|
@ -59,6 +62,7 @@ def print_header(message: str, level: int = 2) -> None:
|
|||
def print_table(rows: Sequence[Sequence[str]], header: Optional[Sequence[str]] = None) -> None:
|
||||
"""
|
||||
Displays the provided content in a formatted HTML table, with optional column headers.
|
||||
|
||||
:param rows: List of rows, where each row is a list of string-valued cell contents.
|
||||
:param header: List of column headers. If given, this special first row is rendered with emphasis.
|
||||
"""
|
||||
|
@ -74,6 +78,7 @@ def print_table(rows: Sequence[Sequence[str]], header: Optional[Sequence[str]] =
|
|||
def generate_notebook(template_notebook: Path, notebook_params: Dict, result_notebook: Path) -> Path:
|
||||
"""
|
||||
Generates a notebook report as jupyter notebook and html page
|
||||
|
||||
:param template_notebook: path to template notebook
|
||||
:param notebook_params: parameters for the notebook
|
||||
:param result_notebook: the path for the executed notebook
|
||||
|
|
|
@ -40,6 +40,7 @@ def plot_scores_for_csv(path_csv: str, outlier_range: float, max_row_count: int)
|
|||
def display_without_index(df: pd.DataFrame) -> None:
|
||||
"""
|
||||
Prints the given dataframe as HTML via the `display` function, but without the index column.
|
||||
|
||||
:param df: The dataframe to print.
|
||||
"""
|
||||
display(HTML(df.to_html(index=False)))
|
||||
|
@ -52,12 +53,13 @@ def display_metric(df: pd.DataFrame,
|
|||
high_values_are_good: bool) -> None:
|
||||
"""
|
||||
Displays a dataframe with a metric per structure, first showing the
|
||||
|
||||
:param max_row_count: The number of rows to print when showing the lowest score patients
|
||||
:param df: The dataframe with metrics per structure
|
||||
:param metric_name: The metric to sort by.
|
||||
:param outlier_range: The standard deviation range data points must fall outside of to be considered an outlier
|
||||
:param high_values_are_good: If true, high values of the metric indicate good performance. If false, low
|
||||
values indicate good performance.
|
||||
values indicate good performance.
|
||||
"""
|
||||
print_header(metric_name, level=2)
|
||||
# Display with best structures first.
|
||||
|
@ -80,11 +82,13 @@ def worst_patients_and_outliers(df: pd.DataFrame,
|
|||
Prints a dataframe that contains the worst patients by the given metric, and a column indicating whether the
|
||||
performance is so poor that it is considered an outlier: metric value which is outside of
|
||||
outlier_range * standard deviation from the mean.
|
||||
|
||||
:param df: The dataframe with metrics.
|
||||
:param outlier_range: The multiplier for standard deviation when constructing the interval for outliers.
|
||||
:param metric_name: The metric for which the "worst" patients should be computed.
|
||||
:param high_values_are_good: If True, high values for the metric are considered good, and hence low values
|
||||
are marked as outliers. If False, low values are considered good, and high values are marked as outliers.
|
||||
are marked as outliers. If False, low values are considered good, and high values are marked as outliers.
|
||||
|
||||
:param max_row_count: The maximum number of rows to print.
|
||||
:return:
|
||||
"""
|
||||
|
|
|
@ -69,6 +69,7 @@ def check_dataset_folder_exists(local_dataset: PathOrString) -> Path:
|
|||
"""
|
||||
Checks if a folder with a local dataset exists. If it does exist, return the argument converted to a Path instance.
|
||||
If it does not exist, raise a FileNotFoundError.
|
||||
|
||||
:param local_dataset: The dataset folder to check.
|
||||
:return: The local_dataset argument, converted to a Path.
|
||||
"""
|
||||
|
@ -83,6 +84,7 @@ def log_metrics(metrics: Dict[ModelExecutionMode, InferenceMetrics],
|
|||
run_context: Run) -> None:
|
||||
"""
|
||||
Log metrics for each split to the provided run, or the current run context if None provided
|
||||
|
||||
:param metrics: Dictionary of inference results for each split.
|
||||
:param run_context: Run for which to log the metrics to, use the current run context if None provided
|
||||
"""
|
||||
|
@ -111,20 +113,26 @@ class MLRunner:
|
|||
"""
|
||||
Driver class to run a ML experiment. Note that the project root argument MUST be supplied when using InnerEye
|
||||
as a package!
|
||||
|
||||
:param model_config: If None, run the training as per the `container` argument (bring-your-own-model). If not
|
||||
None, this is the model configuration for a built-in InnerEye model.
|
||||
None, this is the model configuration for a built-in InnerEye model.
|
||||
|
||||
:param container: The LightningContainer object to use for training. If None, assume that the training is
|
||||
for a built-in InnerEye model.
|
||||
for a built-in InnerEye model.
|
||||
|
||||
:param azure_config: Azure related configurations
|
||||
:param project_root: Project root. This should only be omitted if calling run_ml from the test suite. Supplying
|
||||
it is crucial when using InnerEye as a package or submodule!
|
||||
it is crucial when using InnerEye as a package or submodule!
|
||||
|
||||
:param post_cross_validation_hook: A function to call after waiting for completion of cross validation runs.
|
||||
The function is called with the model configuration and the path to the downloaded and merged metrics files.
|
||||
The function is called with the model configuration and the path to the downloaded and merged metrics files.
|
||||
|
||||
:param model_deployment_hook: an optional function for deploying a model in an application-specific way.
|
||||
If present, it should take a LightningContainer, an AzureConfig, an AzureML Model and a ModelProcessing object
|
||||
If present, it should take a LightningContainer, an AzureConfig, an AzureML Model and a ModelProcessing object
|
||||
as arguments, and return an object of any type.
|
||||
|
||||
:param output_subfolder: If provided, the output folder structure will have an additional subfolder,
|
||||
when running outside AzureML.
|
||||
when running outside AzureML.
|
||||
"""
|
||||
self.model_config = model_config
|
||||
if container is None:
|
||||
|
@ -145,8 +153,9 @@ class MLRunner:
|
|||
"""
|
||||
If the present object is using one of the InnerEye built-in models, create a (fake) container for it
|
||||
and call the setup method. It sets the random seeds, and then creates the actual Lightning modules.
|
||||
|
||||
:param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets.
|
||||
This can be missing when running in unit tests, where the local dataset paths are already populated.
|
||||
This can be missing when running in unit tests, where the local dataset paths are already populated.
|
||||
"""
|
||||
if self._has_setup_run:
|
||||
return
|
||||
|
@ -404,6 +413,7 @@ class MLRunner:
|
|||
def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> None:
|
||||
"""
|
||||
Run inference on the test set for all models that are specified via a LightningContainer.
|
||||
|
||||
:param checkpoint_paths: The path to the checkpoint that should be used for inference.
|
||||
"""
|
||||
if len(checkpoint_paths) != 1:
|
||||
|
@ -461,9 +471,10 @@ class MLRunner:
|
|||
model_proc: ModelProcessing) -> None:
|
||||
"""
|
||||
Run inference on InnerEyeContainer models
|
||||
|
||||
:param checkpoint_paths: Checkpoint paths to initialize model
|
||||
:param model_proc: whether we are running an ensemble model from within a child run with index 0. If we are,
|
||||
then outputs will be written to OTHER_RUNS/ENSEMBLE under the main outputs directory.
|
||||
then outputs will be written to OTHER_RUNS/ENSEMBLE under the main outputs directory.
|
||||
"""
|
||||
|
||||
# run full image inference on existing or newly trained model on the training, and testing set
|
||||
|
@ -515,10 +526,11 @@ class MLRunner:
|
|||
"""
|
||||
Registers a new model in the workspace's model registry on AzureML to be deployed further.
|
||||
The AzureML run's tags are updated to describe with information about ensemble creation and the parent run ID.
|
||||
|
||||
:param checkpoint_paths: Checkpoint paths to register.
|
||||
:param model_proc: whether it's a single or ensemble model.
|
||||
:returns Tuple element 1: AML model object, or None if no model could be registered.
|
||||
Tuple element 2: The result of running the model_deployment_hook, or None if no hook was supplied.
|
||||
:return: Tuple element 1: AML model object, or None if no model could be registered.
|
||||
Tuple element 2: The result of running the model_deployment_hook, or None if no hook was supplied.
|
||||
"""
|
||||
if self.is_offline_run:
|
||||
raise ValueError("Cannot register models when InnerEye is running outside of AzureML.")
|
||||
|
@ -601,9 +613,11 @@ class MLRunner:
|
|||
extra_code_directory, and all checkpoints in a newly created "checkpoints" folder inside the model.
|
||||
In addition, the name of the present AzureML Python environment will be written to a file, for later use
|
||||
in the inference code.
|
||||
|
||||
:param model_folder: The folder into which all files should be copied.
|
||||
:param checkpoint_paths: A list with absolute paths to checkpoint files. They are expected to be
|
||||
inside of the model's checkpoint folder.
|
||||
inside of the model's checkpoint folder.
|
||||
|
||||
:param python_environment: The Python environment that is used in the present AzureML run.
|
||||
"""
|
||||
|
||||
|
@ -704,6 +718,7 @@ class MLRunner:
|
|||
def wait_for_runs_to_finish(self, delay: int = 60) -> None:
|
||||
"""
|
||||
Wait for cross val runs (apart from the current one) to finish and then aggregate results of all.
|
||||
|
||||
:param delay: How long to wait between polls to AML to get status of child runs
|
||||
"""
|
||||
with logging_section("Waiting for sibling runs"):
|
||||
|
@ -716,7 +731,7 @@ class MLRunner:
|
|||
or cancelled.
|
||||
|
||||
:return: True if all sibling runs of the current run have finished (they either completed successfully,
|
||||
or failed). False if any of them is still pending (running or queued).
|
||||
or failed). False if any of them is still pending (running or queued).
|
||||
"""
|
||||
if (not self.is_offline_run) \
|
||||
and (azure_util.is_cross_validation_child_run(RUN_CONTEXT)):
|
||||
|
|
|
@ -110,12 +110,14 @@ class Runner:
|
|||
"""
|
||||
This class contains the high-level logic to start a training run: choose a model configuration by name,
|
||||
submit to AzureML if needed, or otherwise start the actual training and test loop.
|
||||
|
||||
:param project_root: The root folder that contains all of the source code that should be executed.
|
||||
:param yaml_config_file: The path to the YAML file that contains values to supply into sys.argv.
|
||||
:param post_cross_validation_hook: A function to call after waiting for completion of cross validation runs.
|
||||
The function is called with the model configuration and the path to the downloaded and merged metrics files.
|
||||
The function is called with the model configuration and the path to the downloaded and merged metrics files.
|
||||
|
||||
:param model_deployment_hook: an optional function for deploying a model in an application-specific way.
|
||||
If present, it should take a model config (SegmentationModelBase), an AzureConfig, and an AzureML
|
||||
If present, it should take a model config (SegmentationModelBase), an AzureConfig, and an AzureML
|
||||
Model as arguments, and return an optional Path and a further object of any type.
|
||||
"""
|
||||
|
||||
|
@ -203,7 +205,7 @@ class Runner:
|
|||
The main entry point for training and testing models from the commandline. This chooses a model to train
|
||||
via a commandline argument, runs training or testing, and writes all required info to disk and logs.
|
||||
:return: If submitting to AzureML, returns the model configuration that was used for training,
|
||||
including commandline overrides applied (if any).
|
||||
including commandline overrides applied (if any).
|
||||
"""
|
||||
# Usually, when we set logging to DEBUG, we want diagnostics about the model
|
||||
# build itself, but not the tons of debug information that AzureML submissions create.
|
||||
|
@ -378,8 +380,9 @@ class Runner:
|
|||
def run_in_situ(self, azure_run_info: AzureRunInfo) -> None:
|
||||
"""
|
||||
Actually run the AzureML job; this method will typically run on an Azure VM.
|
||||
|
||||
:param azure_run_info: Contains all information about the present run in AzureML, in particular where the
|
||||
datasets are mounted.
|
||||
datasets are mounted.
|
||||
"""
|
||||
# Only set the logging level now. Usually, when we set logging to DEBUG, we want diagnostics about the model
|
||||
# build itself, but not the tons of debug information that AzureML submissions create.
|
||||
|
@ -423,6 +426,7 @@ class Runner:
|
|||
def default_post_cross_validation_hook(config: ModelConfigBase, root_folder: Path) -> None:
|
||||
"""
|
||||
A function to run after cross validation results have been aggregated, before they are uploaded to AzureML.
|
||||
|
||||
:param config: The configuration of the model that should be trained.
|
||||
:param root_folder: The folder with all aggregated and per-split files.
|
||||
"""
|
||||
|
@ -446,7 +450,7 @@ def run(project_root: Path,
|
|||
The main entry point for training and testing models from the commandline. This chooses a model to train
|
||||
via a commandline argument, runs training or testing, and writes all required info to disk and logs.
|
||||
:return: If submitting to AzureML, returns the model configuration that was used for training,
|
||||
including commandline overrides applied (if any). For details on the arguments, see the constructor of Runner.
|
||||
including commandline overrides applied (if any). For details on the arguments, see the constructor of Runner.
|
||||
"""
|
||||
runner = Runner(project_root, yaml_config_file, post_cross_validation_hook, model_deployment_hook)
|
||||
return runner.run()
|
||||
|
|
|
@ -89,10 +89,11 @@ class LabelTransformation(Enum):
|
|||
def get_scaling_transform(max_value: int = 100, min_value: int = 0, last_in_pipeline: bool = True) -> Callable:
|
||||
"""
|
||||
Defines the function to scale labels.
|
||||
|
||||
:param max_value:
|
||||
:param min_value:
|
||||
:param last_in_pipeline: if the transformation is the last
|
||||
in the pipeline it should expect a single label as an argument.
|
||||
in the pipeline it should expect a single label as an argument.
|
||||
Else if returns a list of scaled labels for further transforms.
|
||||
:return: The scaling function
|
||||
"""
|
||||
|
@ -340,6 +341,7 @@ class ScalarModelBase(ModelConfigBase):
|
|||
def filter_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
Filter dataframes based on expected values on columns
|
||||
|
||||
:param df: the input dataframe
|
||||
:return: the filtered dataframe
|
||||
"""
|
||||
|
@ -542,16 +544,15 @@ class ScalarModelBase(ModelConfigBase):
|
|||
data_split: ModelExecutionMode) -> None:
|
||||
"""
|
||||
Computes all the metrics for a given (logits, labels) pair, and writes them to the loggers.
|
||||
|
||||
:param logits: The model output before normalization.
|
||||
:param targets: The expected model outputs.
|
||||
:param subject_ids: The subject IDs for the present minibatch.
|
||||
:param is_training: If True, write the metrics as training metrics, otherwise as validation metrics.
|
||||
:param metrics: A dictionary mapping from names of prediction targets to a list of metric computers,
|
||||
as returned by create_metric_computers.
|
||||
:param metrics: A dictionary mapping from names of prediction targets to a list of metric computers, as returned by create_metric_computers.
|
||||
:param logger: An object of type DataframeLogger which can be be used for logging within this function.
|
||||
:param current_epoch: Current epoch number.
|
||||
:param data_split: ModelExecutionMode object indicating if this is the train or validation split.
|
||||
:return:
|
||||
"""
|
||||
per_subject_outputs: List[Tuple[str, str, torch.Tensor, torch.Tensor]] = []
|
||||
for i, (prediction_target, metric_list) in enumerate(metrics.items()):
|
||||
|
@ -592,9 +593,10 @@ def get_non_image_features_dict(default_channels: List[str],
|
|||
Returns the channels dictionary for non-imaging features.
|
||||
|
||||
:param default_channels: the channels to use for all features except the features specified
|
||||
in specific_channels
|
||||
in specific_channels
|
||||
|
||||
:param specific_channels: a dictionary mapping feature names to channels for all features that do
|
||||
not use the default channels
|
||||
not use the default channels
|
||||
"""
|
||||
non_imaging_features_dict = {DEFAULT_KEY: default_channels}
|
||||
if specific_channels is not None:
|
||||
|
|
|
@ -43,13 +43,14 @@ def load_predictions(run_type: SurfaceDistanceRunType, azure_config: AzureConfig
|
|||
) -> List[Segmentation]:
|
||||
"""
|
||||
For each run type (IOV or outliers), instantiate a list of predicted Segmentations and return
|
||||
|
||||
:param run_type: either "iov" or "outliers:
|
||||
:param azure_config: AzureConfig
|
||||
:param model_config: GenericConfig
|
||||
:param execution_mode: ModelExecutionMode: Either Test, Train or Val
|
||||
:param extended_annotators: List of annotators plus model_name to load segmentations for
|
||||
:param outlier_range: The standard deviation from the mean which the points have to be below
|
||||
to be considered an outlier.
|
||||
to be considered an outlier.
|
||||
:return: list of [(subject_id, structure name and dice_scores)]
|
||||
"""
|
||||
predictions = []
|
||||
|
|
|
@ -74,8 +74,9 @@ class CheckpointHandler:
|
|||
Download checkpoints from a run recovery object or from a weights url. Set the checkpoints path based on the
|
||||
run_recovery_object, weights_url or local_weights_path.
|
||||
This is called at the start of training.
|
||||
|
||||
:param: only_return_path: if True, return a RunRecovery object with the path to the checkpoint without actually
|
||||
downloading the checkpoints. This is useful to avoid duplicating checkpoint download when running on multiple
|
||||
downloading the checkpoints. This is useful to avoid duplicating checkpoint download when running on multiple
|
||||
nodes. If False, return the RunRecovery object and download the checkpoint to disk.
|
||||
"""
|
||||
if self.azure_config.run_recovery_id:
|
||||
|
@ -246,6 +247,7 @@ def download_folder_from_run_to_temp_folder(folder: str,
|
|||
|
||||
:param run: If provided, download the files from that run. If omitted, download the files from the current run
|
||||
(taken from RUN_CONTEXT)
|
||||
|
||||
:param workspace: The AML workspace where the run is located. If omitted, the hi-ml defaults of finding a workspace
|
||||
are used (current workspace when running in AzureML, otherwise expecting a config.json file)
|
||||
:return: The path to which the files were downloaded. The files are located in that folder, without any further
|
||||
|
@ -276,6 +278,7 @@ def find_recovery_checkpoint_on_disk_or_cloud(path: Path) -> Optional[Path]:
|
|||
"""
|
||||
Looks at all the checkpoint files and returns the path to the one that should be used for recovery.
|
||||
If no checkpoint files are found on disk, the function attempts to download from the current AzureML run.
|
||||
|
||||
:param path: The folder to start searching in.
|
||||
:return: None if there is no suitable recovery checkpoints, or else a full path to the checkpoint file.
|
||||
"""
|
||||
|
@ -294,6 +297,7 @@ def get_recovery_checkpoint_path(path: Path) -> Path:
|
|||
"""
|
||||
Returns the path to the last recovery checkpoint in the given folder or the provided filename. Raises a
|
||||
FileNotFoundError if no recovery checkpoint file is present.
|
||||
|
||||
:param path: Path to checkpoint folder
|
||||
"""
|
||||
recovery_checkpoint = find_recovery_checkpoint(path)
|
||||
|
@ -308,6 +312,7 @@ def find_recovery_checkpoint(path: Path) -> Optional[Path]:
|
|||
Finds the checkpoint file in the given path that can be used for re-starting the present job.
|
||||
This can be an autosave checkpoint, or the last checkpoint. All existing checkpoints are loaded, and the one
|
||||
for the highest epoch is used for recovery.
|
||||
|
||||
:param path: The folder to search in.
|
||||
:return: Returns the checkpoint file to use for re-starting, or None if no such file was found.
|
||||
"""
|
||||
|
@ -337,6 +342,7 @@ def find_recovery_checkpoint(path: Path) -> Optional[Path]:
|
|||
def cleanup_checkpoints(path: Path) -> None:
|
||||
"""
|
||||
Remove autosave checkpoints from the given checkpoint folder, and check if a "last.ckpt" checkpoint is present.
|
||||
|
||||
:param path: The folder that contains all checkpoint files.
|
||||
"""
|
||||
logging.info(f"Files in checkpoint folder: {' '.join(p.name for p in path.glob('*'))}")
|
||||
|
@ -359,6 +365,7 @@ def download_best_checkpoints_from_child_runs(config: OutputParams, run: Run) ->
|
|||
The checkpoints for the sibling runs will go into folder 'OTHER_RUNS/<cross_validation_split>'
|
||||
in the checkpoint folder. There is special treatment for the child run that is equal to the present AzureML
|
||||
run, its checkpoints will be read off the checkpoint folder as-is.
|
||||
|
||||
:param config: Model related configs.
|
||||
:param run: The Hyperdrive parent run to download from.
|
||||
:return: run recovery information
|
||||
|
@ -392,12 +399,14 @@ def download_all_checkpoints_from_run(config: OutputParams, run: Run,
|
|||
only_return_path: bool = False) -> RunRecovery:
|
||||
"""
|
||||
Downloads all checkpoints of the provided run inside the checkpoints folder.
|
||||
|
||||
:param config: Model related configs.
|
||||
:param run: Run whose checkpoints should be recovered
|
||||
:param subfolder: optional subfolder name, if provided the checkpoints will be downloaded to
|
||||
CHECKPOINT_FOLDER / subfolder. If None, the checkpoint are downloaded to CHECKPOINT_FOLDER of the current run.
|
||||
CHECKPOINT_FOLDER / subfolder. If None, the checkpoint are downloaded to CHECKPOINT_FOLDER of the current run.
|
||||
|
||||
:param: only_return_path: if True, return a RunRecovery object with the path to the checkpoint without actually
|
||||
downloading the checkpoints. This is useful to avoid duplicating checkpoint download when running on multiple
|
||||
downloading the checkpoints. This is useful to avoid duplicating checkpoint download when running on multiple
|
||||
nodes. If False, return the RunRecovery object and download the checkpoint to disk.
|
||||
:return: run recovery information
|
||||
"""
|
||||
|
|
|
@ -60,6 +60,7 @@ class ModelConfigLoader(GenericConfig):
|
|||
Given a module specification check to see if it has a class property with
|
||||
the <model_name> provided, and instantiate that config class with the
|
||||
provided <config_overrides>. Otherwise, return None.
|
||||
|
||||
:param module_spec:
|
||||
:return: Instantiated model config if it was found.
|
||||
"""
|
||||
|
|
|
@ -39,7 +39,7 @@ def load_csv(csv_path: Path, expected_cols: List[str], col_type_converters: Opti
|
|||
:param csv_path: Path to file
|
||||
:param expected_cols: A list of the columns which must, as a minimum, be present.
|
||||
:param col_type_converters: Dictionary of column: type, which ensures certain DataFrame columns are parsed with
|
||||
specific types
|
||||
specific types
|
||||
:return: Loaded pandas DataFrame
|
||||
"""
|
||||
if not expected_cols:
|
||||
|
@ -61,6 +61,7 @@ def drop_rows_missing_important_values(df: pd.DataFrame, important_cols: List[st
|
|||
"""
|
||||
Remove rows from the DataFrame in which the columns that have been specified by the user as "important" contain
|
||||
null values or only whitespace.
|
||||
|
||||
:param df: DataFrame
|
||||
:param important_cols: Columns which must not contain null values
|
||||
:return: df: DataFrame without the dropped rows.
|
||||
|
@ -83,9 +84,10 @@ def extract_outliers(df: pd.DataFrame, outlier_range: float, outlier_col: str =
|
|||
|
||||
:param df: DataFrame from which to extract the outliers
|
||||
:param outlier_range: The number of standard deviation from the mean which the points have to be apart
|
||||
to be considered an outlier i.e. a point is considered an outlier if its outlier_col value is above
|
||||
to be considered an outlier i.e. a point is considered an outlier if its outlier_col value is above
|
||||
mean + outlier_range * std (if outlier_type is HIGH) or below mean - outlier_range * std (if outlier_type is
|
||||
LOW).
|
||||
|
||||
:param outlier_col: The column from which to calculate outliers, e.g. Dice
|
||||
:param outlier_type: Either LOW (i.e. below accepted range) or HIGH (above accepted range) outliers.
|
||||
:return: DataFrame containing only the outliers
|
||||
|
@ -108,14 +110,16 @@ def mark_outliers(df: pd.DataFrame,
|
|||
Rows that are not considered outliers have an empty string in the new column.
|
||||
Outliers are taken from the column `outlier_col`, that have a value that falls outside of
|
||||
mean +- outlier_range * std.
|
||||
|
||||
:param df: DataFrame from which to extract the outliers
|
||||
:param outlier_range: The number of standard deviation from the mean which the points have to be apart
|
||||
to be considered an outlier i.e. a point is considered an outlier if its outlier_col value is above
|
||||
to be considered an outlier i.e. a point is considered an outlier if its outlier_col value is above
|
||||
mean + outlier_range * std (if outlier_type is HIGH) or below mean - outlier_range * std (if outlier_type is
|
||||
LOW).
|
||||
|
||||
:param outlier_col: The column from which to calculate outliers, e.g. Dice
|
||||
:param high_values_are_good: If True, high values for the metric are considered good, and hence low values
|
||||
are marked as outliers. If False, low values are considered good, and high values are marked as outliers.
|
||||
are marked as outliers. If False, low values are considered good, and high values are marked as outliers.
|
||||
:return: DataFrame with an additional column `is_outlier`
|
||||
"""
|
||||
if outlier_range < 0:
|
||||
|
@ -137,10 +141,12 @@ def get_worst_performing_outliers(df: pd.DataFrame,
|
|||
"""
|
||||
Returns a sorted list (worst to best) of all the worst performing outliers in the metrics table
|
||||
according to metric provided by outlier_col_name
|
||||
|
||||
:param df: Metrics DataFrame
|
||||
:param outlier_col_name: The column by which to determine outliers
|
||||
:param outlier_range: The standard deviation from the mean which the points have to be below
|
||||
to be considered an outlier.
|
||||
to be considered an outlier.
|
||||
|
||||
:param max_n_outliers: the number of (worst performing) outlier IDs to return.
|
||||
:return: a sorted list (worst to best) of all the worst performing outliers
|
||||
"""
|
||||
|
|
|
@ -31,7 +31,7 @@ class CategoricalToOneHotEncoder(OneHotEncoderBase):
|
|||
def __init__(self, columns_and_possible_categories: OrderedDict[str, List[str]]):
|
||||
"""
|
||||
:param columns_and_possible_categories: Mapping between dataset column names
|
||||
to their possible values. eg: {'Inject': ['True', 'False']}. This is required
|
||||
to their possible values. eg: {'Inject': ['True', 'False']}. This is required
|
||||
to establish the one-hot encoding each of the possible values.
|
||||
"""
|
||||
super().__init__()
|
||||
|
@ -75,7 +75,7 @@ class CategoricalToOneHotEncoder(OneHotEncoderBase):
|
|||
|
||||
def get_supported_dataset_column_names(self) -> List[str]:
|
||||
"""
|
||||
:returns list of categorical columns that are supported by this encoder
|
||||
:return: list of categorical columns that are supported by this encoder
|
||||
"""
|
||||
return list(self._columns_and_possible_categories.keys())
|
||||
|
||||
|
@ -86,8 +86,8 @@ class CategoricalToOneHotEncoder(OneHotEncoderBase):
|
|||
of length 3.
|
||||
|
||||
:param feature_name: the name of the column for which to compute the feature
|
||||
length.
|
||||
:returns the feature length i.e. number of possible values for this feature.
|
||||
length.
|
||||
:return: the feature length i.e. number of possible values for this feature.
|
||||
"""
|
||||
return self._feature_length[feature_name]
|
||||
|
||||
|
|
|
@ -24,21 +24,21 @@ class DeviceAwareModule(torch.nn.Module, Generic[T, E]):
|
|||
|
||||
def get_devices(self) -> List[torch.device]:
|
||||
"""
|
||||
:return a list of device ids on which this module
|
||||
is deployed.
|
||||
:return: a list of device ids on which this module
|
||||
is deployed.
|
||||
"""
|
||||
return list({x.device for x in self.parameters()})
|
||||
|
||||
def get_number_trainable_parameters(self) -> int:
|
||||
"""
|
||||
:return the number of trainable parameters in the module.
|
||||
:return: the number of trainable parameters in the module.
|
||||
"""
|
||||
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
||||
|
||||
def is_model_on_gpu(self) -> bool:
|
||||
"""
|
||||
Checks if the model is cuda activated or not
|
||||
:return True if the model is running on the GPU.
|
||||
:return: True if the model is running on the GPU.
|
||||
"""
|
||||
try:
|
||||
cuda_activated = next(self.parameters()).is_cuda
|
||||
|
@ -51,6 +51,7 @@ class DeviceAwareModule(torch.nn.Module, Generic[T, E]):
|
|||
"""
|
||||
Extract the input tensors from a data sample as required
|
||||
by the forward pass of the module.
|
||||
|
||||
:param item: a data sample
|
||||
:return: the correct input tensors for the forward pass
|
||||
"""
|
||||
|
|
|
@ -32,7 +32,7 @@ class FeatureStatistics:
|
|||
|
||||
:param sources: list of data sources
|
||||
:return: a Feature Statistics object storing mean and standard deviation for each non-imaging feature of
|
||||
the dataset.
|
||||
the dataset.
|
||||
"""
|
||||
if len(sources) == 0:
|
||||
raise ValueError("sources must have a length greater than 0")
|
||||
|
@ -88,7 +88,7 @@ class FeatureStatistics:
|
|||
All features that have zero standard deviation (constant features) are left untouched.
|
||||
|
||||
:param sources: list of datasources.
|
||||
:return list of data sources where all non-imaging features are standardized.
|
||||
:return: list of data sources where all non-imaging features are standardized.
|
||||
"""
|
||||
|
||||
def apply_source(source: ScalarDataSource) -> ScalarDataSource:
|
||||
|
|
|
@ -72,9 +72,10 @@ class HDF5Object:
|
|||
def parse_acquisition_date(date: str) -> Optional[datetime]:
|
||||
"""
|
||||
Converts a string representing a date to a datetime object
|
||||
|
||||
:param date: string representing a date
|
||||
:return: converted date, None if the string is invalid for
|
||||
date conversion.
|
||||
date conversion.
|
||||
"""
|
||||
try:
|
||||
return datetime.strptime(date, DATE_FORMAT)
|
||||
|
@ -90,6 +91,7 @@ class HDF5Object:
|
|||
def _load_image(hdf5_data: h5py.File, data_field: HDF5Field) -> np.ndarray:
|
||||
"""
|
||||
Load the volume from the HDF5 file.
|
||||
|
||||
:param hdf5_data: path to the hdf5 file
|
||||
:param data_field: field of the hdf5 file containing the data
|
||||
:return: image as numpy array
|
||||
|
|
|
@ -48,6 +48,7 @@ def get_unit_image_header(spacing: Optional[TupleFloat3] = None) -> ImageHeader:
|
|||
"""
|
||||
Creates an ImageHeader object with the origin at 0, and unit direction. The spacing is set to the argument,
|
||||
defaulting to (1, 1, 1) if not provided.
|
||||
|
||||
:param spacing: The image spacing, as a (Z, Y, X) tuple.
|
||||
"""
|
||||
if not spacing:
|
||||
|
@ -75,7 +76,7 @@ def apply_mask_to_posteriors(posteriors: NumpyOrTorch, mask: NumpyOrTorch) -> Nu
|
|||
|
||||
:param posteriors: image tensors in shape: Batches (optional) x Classes x Z x Y x X
|
||||
:param mask: image tensor in shape: Batches (optional) x Z x Y x X
|
||||
:return posteriors with mask applied
|
||||
:return: posteriors with mask applied
|
||||
"""
|
||||
ml_util.check_size_matches(posteriors, mask, matching_dimensions=[-1, -2, -3])
|
||||
|
||||
|
@ -184,8 +185,9 @@ def _pad_images(images: np.ndarray,
|
|||
|
||||
:param images: the image(s) to be padded, in shape: Z x Y x X or batched in shape: Batches x Z x Y x X.
|
||||
:param padding_vector: padding before and after in each dimension eg: ((2,2), (3,3), (2,0))
|
||||
will pad 4 pixels in Z (2 on each side), 6 pixels in Y (3 on each side)
|
||||
will pad 4 pixels in Z (2 on each side), 6 pixels in Y (3 on each side)
|
||||
and 2 in X (2 on the left and 0 on the right).
|
||||
|
||||
:param padding_mode: a valid numpy padding mode.
|
||||
:return: padded copy of the original image.
|
||||
"""
|
||||
|
@ -208,9 +210,9 @@ def posteriors_to_segmentation(posteriors: NumpyOrTorch) -> NumpyOrTorch:
|
|||
Perform argmax on the class dimension.
|
||||
|
||||
:param posteriors: Confidence maps [0,1] for each patch per class in format: Batches x Class x Z x Y x X
|
||||
or Class x Z x Y x X for non-batched input
|
||||
:returns segmentation: argmaxed posteriors with each voxel belonging to a single class: Batches x Z x Y x X
|
||||
or Z x Y x X for non-batched input
|
||||
or Class x Z x Y x X for non-batched input
|
||||
:return: segmentation: argmaxed posteriors with each voxel belonging to a single class: Batches x Z x Y x X
|
||||
or Z x Y x X for non-batched input
|
||||
"""
|
||||
|
||||
if posteriors is None:
|
||||
|
@ -241,9 +243,11 @@ def largest_connected_components(img: np.ndarray,
|
|||
Select the largest connected binary components (plural) in an image. If deletion_limit is set in which case a
|
||||
component is only deleted (i.e. its voxels are False in the output) if its voxel count as a proportion of all the
|
||||
True voxels in the input is less than deletion_limit.
|
||||
|
||||
:param img: np.ndarray
|
||||
:param deletion_limit: if set, a component is deleted only if its voxel count as a proportion of all the
|
||||
True voxels in the input is less than deletion_limit.
|
||||
True voxels in the input is less than deletion_limit.
|
||||
|
||||
:param class_index: Optional. Can be used to provide a class index for logging purposes if the image contains
|
||||
only pixels from a specific class.
|
||||
"""
|
||||
|
@ -281,9 +285,10 @@ def extract_largest_foreground_connected_component(
|
|||
restrictions: Optional[List[Tuple[int, Optional[float]]]] = None) -> np.ndarray:
|
||||
"""
|
||||
Extracts the largest foreground connected component per class from a multi-label array.
|
||||
|
||||
:param multi_label_array: An array of class assignments, i.e. value c at (z, y, x) is a class c.
|
||||
:param restrictions: restrict processing to a subset of the classes (if provided). Each element is a
|
||||
pair (class_index, threshold) where threshold may be None.
|
||||
pair (class_index, threshold) where threshold may be None.
|
||||
:return: An array of class assignments
|
||||
"""
|
||||
if restrictions is None:
|
||||
|
@ -311,6 +316,7 @@ def merge_masks(masks: np.ndarray) -> np.ndarray:
|
|||
"""
|
||||
Merges a one-hot encoded mask tensor (Classes x Z x Y x X) into a multi-label map with labels corresponding to their
|
||||
index in the original tensor of shape (Z x Y x X).
|
||||
|
||||
:param masks: array of shape (Classes x Z x Y x X) containing the mask for each class
|
||||
:return: merged_mask of shape (Z x Y x X).
|
||||
"""
|
||||
|
@ -346,7 +352,7 @@ def multi_label_array_to_binary(array: np.ndarray, num_classes_including_backgro
|
|||
|
||||
:param array: An array of class assignments.
|
||||
:param num_classes_including_background: The number of class assignments to search for. If 3 classes,
|
||||
the class assignments to search for will be 0, 1, and 2.
|
||||
the class assignments to search for will be 0, 1, and 2.
|
||||
:return: an array of size (num_classes_including_background, array.shape)
|
||||
"""
|
||||
return np.stack(list(binaries_from_multi_label_array(array, num_classes_including_background)))
|
||||
|
@ -368,7 +374,7 @@ def get_center_crop(image: NumpyOrTorch, crop_shape: TupleInt3) -> NumpyOrTorch:
|
|||
|
||||
:param image: The original image to extract crop from
|
||||
:param crop_shape: The shape of the center crop to extract
|
||||
:return the center region as specified by the crop_shape argument.
|
||||
:return: the center region as specified by the crop_shape argument.
|
||||
"""
|
||||
if image is None or crop_shape is None:
|
||||
raise Exception("image and crop_shape must not be None")
|
||||
|
@ -399,6 +405,7 @@ def check_array_range(data: np.ndarray, expected_range: Optional[Range] = None,
|
|||
:param data: The array to check. It can have any size.
|
||||
:param expected_range: The interval that all array elements must fall into. The first entry is the lower
|
||||
bound, the second entry is the upper bound.
|
||||
|
||||
:param error_prefix: A string to use as the prefix for the error message.
|
||||
"""
|
||||
if expected_range is None:
|
||||
|
@ -498,9 +505,9 @@ def compute_uncertainty_map_from_posteriors(posteriors: np.ndarray) -> np.ndarra
|
|||
Normalized Shannon Entropy: https://en.wiktionary.org/wiki/Shannon_entropy
|
||||
|
||||
:param posteriors: Normalized probability distribution in range [0, 1] for each class,
|
||||
in shape: Class x Z x Y x X
|
||||
in shape: Class x Z x Y x X
|
||||
:return: Shannon Entropy for each voxel, shape: Z x Y x X expected range is [0,1] where 1 represents
|
||||
low confidence or uniform posterior distribution across classes.
|
||||
low confidence or uniform posterior distribution across classes.
|
||||
"""
|
||||
check_if_posterior_array(posteriors)
|
||||
|
||||
|
@ -513,10 +520,11 @@ def gaussian_smooth_posteriors(posteriors: np.ndarray, kernel_size_mm: TupleFloa
|
|||
Performs Gaussian smoothing on posteriors
|
||||
|
||||
:param posteriors: Normalized probability distribution in range [0, 1] for each class,
|
||||
in shape: Class x Z x Y x X
|
||||
in shape: Class x Z x Y x X
|
||||
|
||||
:param kernel_size_mm: The size of the smoothing kernel in mm to be used in each dimension (Z, Y, X)
|
||||
:param voxel_spacing_mm: Voxel spacing to use to map from mm space to pixel space for the
|
||||
Gaussian sigma parameter for each dimension in (Z x Y x X) order.
|
||||
Gaussian sigma parameter for each dimension in (Z x Y x X) order.
|
||||
:return:
|
||||
"""
|
||||
check_if_posterior_array(posteriors)
|
||||
|
@ -557,10 +565,11 @@ def segmentation_to_one_hot(segmentation: torch.Tensor,
|
|||
|
||||
:param segmentation: A segmentation as a multi-label map of shape [B, C, Z, Y, X]
|
||||
:param use_gpu: If true, and the input is not yet on the GPU, move the intermediate tensors to the GPU. The result
|
||||
will be on the same device as the argument `segmentation`
|
||||
will be on the same device as the argument `segmentation`
|
||||
|
||||
:param result_dtype: The torch data type that the result tensor should have. This would be either float16 or float32
|
||||
:return: A torch tensor with one-hot encoding of the segmentation of shape
|
||||
[B, C*HDF5_NUM_SEGMENTATION_CLASSES, Z, Y, X]
|
||||
[B, C*HDF5_NUM_SEGMENTATION_CLASSES, Z, Y, X]
|
||||
"""
|
||||
|
||||
def to_cuda(x: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -702,6 +711,7 @@ def apply_summed_probability_rules(model_config: SegmentationModelBase,
|
|||
:param model_config: Model configuration information
|
||||
:param posteriors: Confidences per voxel per class, in format Batch x Classes x Z x Y x X if batched,
|
||||
or Classes x Z x Y x X if not batched.
|
||||
|
||||
:param segmentation: Class labels per voxel, in format Batch x Z x Y x X if batched, or Z x Y x X if not batched.
|
||||
:return: Modified segmentation, as Batch x Z x Y x X if batched, or Z x Y x X if not batched.
|
||||
"""
|
||||
|
|
|
@ -207,6 +207,7 @@ def load_nifti_image(path: PathOrString, image_type: Optional[Type] = float) ->
|
|||
|
||||
:param path: The path to the image to load.
|
||||
:return: A numpy array of the image and header data if applicable.
|
||||
|
||||
:param image_type: The type to load the image in, set to None to not cast, default is float
|
||||
:raises ValueError: If the path is invalid or the image is not 3D.
|
||||
"""
|
||||
|
@ -214,6 +215,7 @@ def load_nifti_image(path: PathOrString, image_type: Optional[Type] = float) ->
|
|||
def _is_valid_image_path(_path: Path) -> bool:
|
||||
"""
|
||||
Validates a path for an image. Image must be .nii, or .nii.gz.
|
||||
|
||||
:param _path: The path to the file.
|
||||
:return: True if it is valid, False otherwise
|
||||
"""
|
||||
|
@ -241,6 +243,7 @@ def load_nifti_image(path: PathOrString, image_type: Optional[Type] = float) ->
|
|||
def load_numpy_image(path: PathOrString, image_type: Optional[Type] = None) -> np.ndarray:
|
||||
"""
|
||||
Loads an array from a numpy file (npz or npy). The array is converted to image_type or untouched if None
|
||||
|
||||
:param path: The path to the numpy file.
|
||||
:param image_type: The dtype to cast the array
|
||||
:return: ndarray
|
||||
|
@ -258,6 +261,7 @@ def load_numpy_image(path: PathOrString, image_type: Optional[Type] = None) -> n
|
|||
def load_dicom_image(path: PathOrString) -> np.ndarray:
|
||||
"""
|
||||
Loads an array from a single dicom file.
|
||||
|
||||
:param path: The path to the dicom file.
|
||||
"""
|
||||
ds = dicom.dcmread(path)
|
||||
|
@ -278,6 +282,7 @@ def load_dicom_image(path: PathOrString) -> np.ndarray:
|
|||
def load_hdf5_dataset_from_file(path_str: Path, dataset_name: str) -> np.ndarray:
|
||||
"""
|
||||
Loads a hdf5 dataset from a file as an ndarray
|
||||
|
||||
:param path_str: The path to the HDF5 file
|
||||
:param dataset_name: The dataset name in the HDF5 file that we want to load
|
||||
:return: ndarray
|
||||
|
@ -292,15 +297,17 @@ def load_hdf5_dataset_from_file(path_str: Path, dataset_name: str) -> np.ndarray
|
|||
def load_hdf5_file(path_str: Union[str, Path], load_segmentation: bool = False) -> HDF5Object:
|
||||
"""
|
||||
Loads a single HDF5 file.
|
||||
|
||||
:param path_str: The path of the HDF5 file that should be loaded.
|
||||
:param load_segmentation: If True, the `segmentation` field of the result object will be populated. If
|
||||
False, the field will be set to None.
|
||||
False, the field will be set to None.
|
||||
:return: HDF5Object
|
||||
"""
|
||||
|
||||
def _is_valid_hdf5_path(_path: Path) -> bool:
|
||||
"""
|
||||
Validates a path for an image
|
||||
|
||||
:param _path:
|
||||
:return:
|
||||
"""
|
||||
|
@ -331,12 +338,14 @@ def load_images_and_stack(files: Iterable[Path],
|
|||
|
||||
:param files: The paths of the files to load.
|
||||
:param load_segmentation: If True it loads segmentation if present on the same file as the image. This is only
|
||||
supported for loading from HDF5 files.
|
||||
supported for loading from HDF5 files.
|
||||
|
||||
:param center_crop_size: If supplied, all loaded images will be cropped to the size given here. The crop will be
|
||||
taken from the center of the image.
|
||||
taken from the center of the image.
|
||||
|
||||
:param image_size: If supplied, all loaded images will be resized immediately after loading.
|
||||
:return: A wrapper class that contains the loaded images, and if load_segmentation is True, also the segmentations
|
||||
that were present in the files.
|
||||
that were present in the files.
|
||||
"""
|
||||
images = []
|
||||
segmentations = []
|
||||
|
@ -420,6 +429,7 @@ def load_labels_from_dataset_source(dataset_source: PatientDatasetSource, check_
|
|||
In the future, this function will be used to load global class and non-imaging information as well.
|
||||
|
||||
:type image_size: Image size, tuple of integers.
|
||||
|
||||
:param dataset_source: The dataset source for which channels are to be loaded into memory.
|
||||
:param check_exclusive: Check that the labels are mutually exclusive (defaults to True).
|
||||
:return: A label sample object containing ground-truth information.
|
||||
|
@ -467,6 +477,7 @@ def load_image(path: PathOrString, image_type: Optional[Type] = float) -> ImageW
|
|||
For segmentation binary |<dataset_name>|<channel index>
|
||||
For segmentation multimap |<dataset_name>|<channel index>|<multimap value>
|
||||
The expected dimensions to be (channel, Z, Y, X)
|
||||
|
||||
:param path: The path to the file
|
||||
:param image_type: The type of the image
|
||||
"""
|
||||
|
@ -625,7 +636,7 @@ def store_binary_mask_as_nifti(image: np.ndarray, header: ImageHeader, file_name
|
|||
:param header: The image header
|
||||
:param file_name: The name of the file for this image.
|
||||
:return: the path to the saved image
|
||||
:raises: when image is not binary
|
||||
:raises Exception: when image is not binary
|
||||
"""
|
||||
if not is_binary_array(image):
|
||||
raise Exception("Array values must be binary.")
|
||||
|
|
|
@ -38,14 +38,17 @@ def get_padding_from_kernel_size(padding: PaddingMode,
|
|||
Returns padding value required for convolution layers based on input kernel size and dilation.
|
||||
|
||||
:param padding: Padding type (Enum) {`zero`, `no_padding`}. Option `zero` is intended to preserve the tensor shape.
|
||||
In `no_padding` option, padding is not applied and the function returns only zeros.
|
||||
In `no_padding` option, padding is not applied and the function returns only zeros.
|
||||
|
||||
:param kernel_size: Spatial support of the convolution kernel. It is used to determine the padding size. This can be
|
||||
a scalar, tuple or array.
|
||||
a scalar, tuple or array.
|
||||
|
||||
:param dilation: Dilation of convolution kernel. It is used to determine the padding size. This can be a scalar,
|
||||
tuple or array.
|
||||
tuple or array.
|
||||
|
||||
:param num_dimensions: The number of dimensions that the returned padding tuple should have, if both
|
||||
kernel_size and dilation are scalars.
|
||||
:return padding value required for convolution layers based on input kernel size and dilation.
|
||||
kernel_size and dilation are scalars.
|
||||
:return: padding value required for convolution layers based on input kernel size and dilation.
|
||||
"""
|
||||
if isinstance(kernel_size, Sized):
|
||||
num_dimensions = len(kernel_size)
|
||||
|
@ -70,8 +73,9 @@ def get_upsampling_kernel_size(downsampling_factor: IntOrTuple3, num_dimensions:
|
|||
https://distill.pub/2016/deconv-checkerboard/
|
||||
|
||||
:param downsampling_factor: downsampling factor use for each dimension of the kernel. Can be
|
||||
either a list of len(num_dimension) with one factor per dimension or an int in which case the
|
||||
either a list of len(num_dimension) with one factor per dimension or an int in which case the
|
||||
same factor will be applied for all dimension.
|
||||
|
||||
:param num_dimensions: number of dimensions of the kernel
|
||||
:return: upsampling_kernel_size
|
||||
"""
|
||||
|
@ -98,6 +102,7 @@ def set_model_to_eval_mode(model: torch.nn.Module) -> Generator:
|
|||
"""
|
||||
Puts the given torch model into eval mode. At the end of the context, resets the state of the training flag to
|
||||
what is was before the call.
|
||||
|
||||
:param model: The model to modify.
|
||||
"""
|
||||
old_mode = model.training
|
||||
|
|
|
@ -54,6 +54,7 @@ class MetricsPerPatientWriter:
|
|||
"""
|
||||
Writes the per-patient per-structure metrics to a CSV file.
|
||||
Sorting is done first by structure name, then by Dice score ascending.
|
||||
|
||||
:param file_name: The name of the file to write to.
|
||||
"""
|
||||
df = self.to_data_frame()
|
||||
|
@ -70,7 +71,7 @@ class MetricsPerPatientWriter:
|
|||
|
||||
:param file_path: The name of the file to write to.
|
||||
:param allow_incomplete_labels: boolean flag. If false, all ground truth files must be provided.
|
||||
If true, ground truth files are optional and we add a total_patients count column for easy
|
||||
If true, ground truth files are optional and we add a total_patients count column for easy
|
||||
comparison. (Defaults to False.)
|
||||
"""
|
||||
|
||||
|
@ -130,6 +131,7 @@ class MetricsPerPatientWriter:
|
|||
def get_number_of_voxels_per_class(labels: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Computes the number of voxels for each class in a one-hot label map.
|
||||
|
||||
:param labels: one-hot label map in shape Batches x Classes x Z x Y x X or Classes x Z x Y x X
|
||||
:return: A tensor of shape [Batches x Classes] containing the number of non-zero voxels along Z, Y, X
|
||||
"""
|
||||
|
@ -208,9 +210,9 @@ def binary_classification_accuracy(model_output: Union[torch.Tensor, np.ndarray]
|
|||
:param model_output: A tensor containing model outputs.
|
||||
:param label: A tensor containing class labels.
|
||||
:param threshold: the cut-off probability threshold for predictions. If model_ouput is > threshold, the predicted
|
||||
class is 1 else 0.
|
||||
class is 1 else 0.
|
||||
:return: 1.0 if all predicted classes match the expected classes given in 'labels'. 0.0 if no predicted classes
|
||||
match their labels.
|
||||
match their labels.
|
||||
"""
|
||||
model_output, label = convert_input_and_label(model_output, label)
|
||||
predicted_class = model_output > threshold
|
||||
|
@ -254,7 +256,7 @@ def convert_input_and_label(model_output: Union[torch.Tensor, np.ndarray],
|
|||
label: Union[torch.Tensor, np.ndarray]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Ensures that both model_output and label are tensors of dtype float32.
|
||||
:return a Tuple with model_output, label as float tensors.
|
||||
:return: a Tuple with model_output, label as float tensors.
|
||||
"""
|
||||
if not torch.is_tensor(model_output):
|
||||
model_output = torch.tensor(model_output)
|
||||
|
@ -268,8 +270,9 @@ def is_missing_ground_truth(ground_truth: np.ndarray) -> bool:
|
|||
calculate_metrics_per_class in metrics.py and plot_contours_for_all_classes in plotting.py both
|
||||
check whether there is ground truth missing using this simple check for NaN value at 0, 0, 0.
|
||||
To avoid duplicate code we bring it here as a utility function.
|
||||
|
||||
:param ground_truth: ground truth binary array with dimensions: [Z x Y x X].
|
||||
:param label_id: Integer index of the label to check.
|
||||
:returns: True if the label is missing (signified by NaN), False otherwise.
|
||||
:return: True if the label is missing (signified by NaN), False otherwise.
|
||||
"""
|
||||
return np.isnan(ground_truth[0, 0, 0])
|
||||
|
|
|
@ -82,8 +82,8 @@ def validate_dataset_paths(
|
|||
Validates that the required dataset csv file exists in the given path.
|
||||
|
||||
:param dataset_path: The base path
|
||||
:param custom_dataset_csv : The name of the dataset csv file
|
||||
:raise ValueError if the dataset does not exist.
|
||||
:param custom_dataset_csv: The name of the dataset csv file
|
||||
:raises: ValueError if the dataset does not exist.
|
||||
"""
|
||||
if not dataset_path.is_dir():
|
||||
raise ValueError("The dataset_path argument should be the path to the base directory of the data "
|
||||
|
@ -107,14 +107,17 @@ def check_size_matches(arg1: Union[np.ndarray, torch.Tensor],
|
|||
:param arg1: The first array to check.
|
||||
:param arg2: The second array to check.
|
||||
:param dim1: The expected number of dimensions of arg1. If zero, no check for number of dimensions will be
|
||||
conducted.
|
||||
conducted.
|
||||
|
||||
:param dim2: The expected number of dimensions of arg2. If zero, no check for number of dimensions will be
|
||||
conducted.
|
||||
conducted.
|
||||
|
||||
:param matching_dimensions: The dimensions along which the two arguments have to match. For example, if
|
||||
arg1.ndim==4 and arg2.ndim==5, matching_dimensions==[3] checks if arg1.shape[3] == arg2.shape[3].
|
||||
arg1.ndim==4 and arg2.ndim==5, matching_dimensions==[3] checks if arg1.shape[3] == arg2.shape[3].
|
||||
|
||||
:param arg1_name: If provided, all error messages will use that string to instead of "arg1"
|
||||
:param arg2_name: If provided, all error messages will use that string to instead of "arg2"
|
||||
:raise ValueError if shapes don't match
|
||||
:raises: ValueError if shapes don't match
|
||||
"""
|
||||
if arg1 is None or arg2 is None:
|
||||
raise Exception("arg1 and arg2 cannot be None.")
|
||||
|
@ -125,10 +128,11 @@ def check_size_matches(arg1: Union[np.ndarray, torch.Tensor],
|
|||
def check_dim(expected: int, actual_shape: Any, name: str) -> None:
|
||||
"""
|
||||
Check if actual_shape is equal to the expected shape
|
||||
|
||||
:param expected: expected shape
|
||||
:param actual_shape:
|
||||
:param name: variable name
|
||||
:raise ValueError if not the same shape
|
||||
:raises: ValueError if not the same shape
|
||||
"""
|
||||
if len(actual_shape) != expected:
|
||||
raise ValueError("'{}' was expected to have ndim == {}, but is {}. Shape is {}"
|
||||
|
@ -151,6 +155,7 @@ def check_size_matches(arg1: Union[np.ndarray, torch.Tensor],
|
|||
def set_random_seed(random_seed: int, caller_name: Optional[str] = None) -> None:
|
||||
"""
|
||||
Set the seed for the random number generators of python, numpy, torch.random, and torch.cuda for all gpus.
|
||||
|
||||
:param random_seed: random seed value to set.
|
||||
:param caller_name: name of the caller for logging purposes.
|
||||
"""
|
||||
|
@ -170,8 +175,9 @@ def is_test_from_execution_mode(execution_mode: ModelExecutionMode) -> bool:
|
|||
"""
|
||||
Returns a boolean by checking the execution type. The output is used to determine the properties
|
||||
of the forward pass, e.g. model gradient updates or metric computation.
|
||||
:return True if execution mode is VAL or TEST, False if TRAIN
|
||||
:raise ValueError if the execution mode is invalid
|
||||
|
||||
:return: True if execution mode is VAL or TEST, False if TRAIN
|
||||
:raises ValueError: if the execution mode is invalid
|
||||
"""
|
||||
if execution_mode == ModelExecutionMode.TRAIN:
|
||||
return False
|
||||
|
@ -207,6 +213,6 @@ def is_tensor_nan(tensor: torch.Tensor) -> bool:
|
|||
|
||||
:param tensor: The tensor to check.
|
||||
:return: True if any of the tensor elements is Not a Number, False if all entries are valid numbers.
|
||||
If the tensor is empty, the function returns False.
|
||||
If the tensor is empty, the function returns False.
|
||||
"""
|
||||
return bool(torch.isnan(tensor).any().item())
|
||||
|
|
|
@ -11,6 +11,7 @@ from InnerEye.Common.type_annotations import TupleInt3
|
|||
def random_colour(rng: random.Random) -> TupleInt3:
|
||||
"""
|
||||
Generates a random colour in RGB given a random number generator
|
||||
|
||||
:param rng: Random number generator
|
||||
:return: Tuple with random colour in RGB
|
||||
"""
|
||||
|
@ -23,6 +24,7 @@ def random_colour(rng: random.Random) -> TupleInt3:
|
|||
def generate_random_colours_list(rng: random.Random, size: int) -> List[TupleInt3]:
|
||||
"""
|
||||
Generates a list of random colours in RGB given a random number generator and the size of this list
|
||||
|
||||
:param rng: random number generator
|
||||
:param size: size of the list
|
||||
:return: list of random colours in RGB
|
||||
|
|
|
@ -203,6 +203,7 @@ def generate_and_print_model_summary(config: ModelConfigBase, model: DeviceAware
|
|||
"""
|
||||
Writes a human readable summary of the present model to logging.info, and logs the number of trainable
|
||||
parameters to AzureML.
|
||||
|
||||
:param config: The configuration for the model.
|
||||
:param model: The instantiated Pytorch model.
|
||||
"""
|
||||
|
@ -264,6 +265,7 @@ class ScalarModelInputsAndLabels():
|
|||
def move_to_device(self, device: Union[str, torch.device]) -> None:
|
||||
"""
|
||||
Moves the model_inputs and labels field of the present object to the given device. This is done in-place.
|
||||
|
||||
:param device: The target device.
|
||||
"""
|
||||
self.model_inputs = [t.to(device=device) for t in self.model_inputs]
|
||||
|
@ -274,13 +276,15 @@ def get_scalar_model_inputs_and_labels(model: torch.nn.Module,
|
|||
sample: Dict[str, Any]) -> ScalarModelInputsAndLabels:
|
||||
"""
|
||||
For a model that predicts scalars, gets the model input tensors from a sample returned by the data loader.
|
||||
|
||||
:param model: The instantiated PyTorch model.
|
||||
:param target_indices: If this list is non-empty, assume that the model is a sequence model, and build the
|
||||
model inputs and labels for a model that predicts those specific positions in the sequence. If the list is empty,
|
||||
model inputs and labels for a model that predicts those specific positions in the sequence. If the list is empty,
|
||||
assume that the model is a normal (non-sequence) model.
|
||||
|
||||
:param sample: A training sample, as returned by a PyTorch data loader (dictionary mapping from field name to value)
|
||||
:return: An instance of ScalarModelInputsAndLabels, containing the list of model input tensors,
|
||||
label tensor, subject IDs, and the data item reconstructed from the data loader output
|
||||
label tensor, subject IDs, and the data item reconstructed from the data loader output
|
||||
"""
|
||||
scalar_model: DeviceAwareModule[ScalarItem, torch.Tensor] = model # type: ignore
|
||||
scalar_item = ScalarItem.from_dict(sample)
|
||||
|
|
|
@ -14,6 +14,7 @@ def get_view_dim_and_origin(plane: Plane) -> Tuple[int, str]:
|
|||
"""
|
||||
Get the axis along which to slice, as well as the orientation of the origin, to ensure images
|
||||
are plotted as expected
|
||||
|
||||
:param plane: the plane in which to plot (i.e. axial, sagittal or coronal)
|
||||
:return:
|
||||
"""
|
||||
|
@ -33,6 +34,7 @@ def get_cropped_axes(image: np.ndarray, boundary_width: int = 5) -> Tuple[slice,
|
|||
"""
|
||||
Return the min and max values on both x and y axes where the image is not empty
|
||||
Method: find the min and max of all non-zero pixels in the image, and add a border
|
||||
|
||||
:param image: the image to be cropped
|
||||
:param boundary_width: number of pixels boundary to add around bounding box
|
||||
:return:
|
||||
|
|
|
@ -50,10 +50,11 @@ def sequences_to_padded_tensor(sequences: List[torch.Tensor],
|
|||
padding_value: float = 0.0) -> torch.Tensor:
|
||||
"""
|
||||
Method to convert possibly unequal length sequences to a padded tensor.
|
||||
|
||||
:param sequences: List of Tensors to pad
|
||||
:param padding_value: Padding value to use, default is 0.0
|
||||
:return: Output tensor with shape B x * where * is the max dimensions from the list of provided tensors.
|
||||
And B is the number of tensors in the list of sequences provided.
|
||||
And B is the number of tensors in the list of sequences provided.
|
||||
"""
|
||||
return pad_sequence(sequences, batch_first=True, padding_value=padding_value)
|
||||
|
||||
|
@ -80,6 +81,7 @@ def get_masked_model_outputs_and_labels(model_output: torch.Tensor,
|
|||
Helper function to get masked model outputs, labels and their associated subject ids. Masking is performed
|
||||
by excluding the NaN model outputs and labels based on a bool mask created using the
|
||||
occurrences of NaN in the labels provided.
|
||||
|
||||
:param model_output: The model output tensor to mask.
|
||||
:param labels: The label tensor to use for mask, and use for masking.
|
||||
:param subject_ids: The associated subject ids.
|
||||
|
@ -125,6 +127,7 @@ def apply_sequence_model_loss(loss_fn: torch.nn.Module,
|
|||
"""
|
||||
Applies a loss function to a model output and labels, when the labels come from sequences with unequal length.
|
||||
Missing sequence elements are masked out.
|
||||
|
||||
:param loss_fn: The loss function to apply to the sequence elements that are present.
|
||||
:param model_output: The model outputs
|
||||
:param labels: The ground truth labels.
|
||||
|
|
Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше
Загрузка…
Ссылка в новой задаче