DOC: Add all `InnerEye/ML` docstrings to ReadTheDocs (#783)

* 📝 Create basic for ML API

* 📝 Add ML/configs base doc files

* 📝 Finish ML/configs API

* 📝 Update augmentations

* 📝 Add ML/dataset API docs

* 📝 Add rst skeleton for ML/models

* 📝 Fix docstring missing newlines

* Remove script

* 📝 Finish ML/models API docs

* 📝 Start ML/SSL API. Fix some formatting issues

* 📝 Correct whitespace issues in `:param`

* 📝 Fix whitespace errors on `:return` statements

* 📝 Fix :return: statements

* 📝 Finish ML/SSL API

* 📝 Add ML/utils API docs

* 📝 Add visualizer docs, fix `:raise` indents

* 📝 Fix more issues with the `:raises:` formatting

* ♻️ Restructuring folders

* 📝 Limit API `toctree` depth

* 📝 Add primary InnerEye/ML files API to docs

* 📝 Fix and add `InnerEye/ML/*.py` docs

* ⚰️ Remove weird `settings.json` change

* ♻️ 💡 Address review comments
This commit is contained in:
Peter Hessey 2022-08-16 09:58:38 +01:00 коммит произвёл GitHub
Родитель c1b363e158
Коммит 59214c268e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
175 изменённых файлов: 1482 добавлений и 456 удалений

2
.vscode/settings.json поставляемый
Просмотреть файл

@ -12,5 +12,5 @@
},
"files.trimTrailingWhitespace": true,
"files.trimFinalNewlines": true,
"files.insertFinalNewline": true
"files.insertFinalNewline": true,
}

Просмотреть файл

@ -168,6 +168,7 @@ class AzureConfig(GenericConfig):
"""
Creates an AzureConfig object with default values, with the keys/secrets populated from values in the
given YAML file. If a `project_root` folder is provided, a private settings file is read from there as well.
:param yaml_file_path: Path to the YAML file that contains values to create the AzureConfig
:param project_root: A folder in which to search for a private settings file.
:return: AzureConfig with values populated from the yaml files.
@ -231,6 +232,7 @@ class AzureConfig(GenericConfig):
def fetch_run(self, run_recovery_id: str) -> Run:
"""
Gets an instantiated Run object for a given run recovery ID (format experiment_name:run_id).
:param run_recovery_id: A run recovery ID (format experiment_name:run_id)
"""
return fetch_run(workspace=self.get_workspace(), run_recovery_id=run_recovery_id)

Просмотреть файл

@ -39,6 +39,7 @@ def get_git_tags(azure_config: AzureConfig) -> Dict[str, str]:
Creates a dictionary with git-related information, like branch and commit ID. The dictionary key is a string
that can be used as a tag on an AzureML run, the dictionary value is the git information. If git information
is passed in via commandline arguments, those take precedence over information read out from the repository.
:param azure_config: An AzureConfig object specifying git-related commandline args.
:return: A dictionary mapping from tag name to git info.
"""
@ -56,6 +57,7 @@ def get_git_tags(azure_config: AzureConfig) -> Dict[str, str]:
def additional_run_tags(azure_config: AzureConfig, commandline_args: str) -> Dict[str, str]:
"""
Gets the set of tags that will be added to the AzureML run as metadata, like git status and user name.
:param azure_config: The configurations for the present AzureML job
:param commandline_args: A string that holds all commandline arguments that were used for the present run.
"""
@ -77,6 +79,7 @@ def additional_run_tags(azure_config: AzureConfig, commandline_args: str) -> Dic
def create_experiment_name(azure_config: AzureConfig) -> str:
"""
Gets the name of the AzureML experiment. This is taken from the commandline, or from the git branch.
:param azure_config: The object containing all Azure-related settings.
:return: The name to use for the AzureML experiment.
"""
@ -104,7 +107,7 @@ def create_dataset_configs(azure_config: AzureConfig,
:param all_dataset_mountpoints: When using the datasets in AzureML, these are the per-dataset mount points.
:param all_local_datasets: The paths for all local versions of the datasets.
:return: A list of DatasetConfig objects, in the same order as datasets were provided in all_azure_dataset_ids,
omitting datasets with an empty name.
omitting datasets with an empty name.
"""
datasets: List[DatasetConfig] = []
num_local = len(all_local_datasets)
@ -147,6 +150,7 @@ def create_runner_parser(model_config_class: type = None) -> argparse.ArgumentPa
"""
Creates a commandline parser, that understands all necessary arguments for running a script in Azure,
plus all arguments for the given class. The class must be a subclass of GenericConfig.
:param model_config_class: A class that contains the model-specific parameters.
:return: An instance of ArgumentParser.
"""
@ -167,11 +171,12 @@ def parse_args_and_add_yaml_variables(parser: ArgumentParser,
"""
Reads arguments from sys.argv, modifies them with secrets from local YAML files,
and parses them using the given argument parser.
:param project_root: The root folder for the whole project. Only used to access a private settings file.
:param parser: The parser to use.
:param yaml_config_file: The path to the YAML file that contains values to supply into sys.argv.
:param fail_on_unknown_args: If True, raise an exception if the parser encounters an argument that it does not
recognize. If False, unrecognized arguments will be ignored, and added to the "unknown" field of the parser result.
recognize. If False, unrecognized arguments will be ignored, and added to the "unknown" field of the parser result.
:return: The parsed arguments, and overrides
"""
settings_from_yaml = read_all_settings(yaml_config_file, project_root=project_root)
@ -183,6 +188,7 @@ def parse_args_and_add_yaml_variables(parser: ArgumentParser,
def _create_default_namespace(parser: ArgumentParser) -> Namespace:
"""
Creates an argparse Namespace with all parser-specific default values set.
:param parser: The parser to work with.
:return:
"""
@ -207,10 +213,12 @@ def parse_arguments(parser: ArgumentParser,
Parses a list of commandline arguments with a given parser, and adds additional information read
from YAML files. Returns results broken down into a full arguments dictionary, a dictionary of arguments
that were set to non-default values, and unknown arguments.
:param parser: The parser to use
:param settings_from_yaml: A dictionary of settings read from a YAML config file.
:param fail_on_unknown_args: If True, raise an exception if the parser encounters an argument that it does not
recognize. If False, unrecognized arguments will be ignored, and added to the "unknown" field of the parser result.
recognize. If False, unrecognized arguments will be ignored, and added to the "unknown" field of the parser result.
:param args: Arguments to parse. If not given, use those in sys.argv
:return: The parsed arguments, and overrides
"""
@ -261,6 +269,7 @@ def run_duration_string_to_seconds(s: str) -> Optional[int]:
Parse a string that represents a timespan, and returns it converted into seconds. The string is expected to be
floating point number with a single character suffix s, m, h, d for seconds, minutes, hours, day.
Examples: '3.5h', '2d'. If the argument is an empty string, None is returned.
:param s: The string to parse.
:return: The timespan represented in the string converted to seconds.
"""

Просмотреть файл

@ -47,6 +47,7 @@ def split_recovery_id(id: str) -> Tuple[str, str]:
The argument can be in the format 'experiment_name:run_id',
or just a run ID like user_branch_abcde12_123. In the latter case, everything before the last
two alphanumeric parts is assumed to be the experiment name.
:param id:
:return: experiment name and run name
"""
@ -74,9 +75,10 @@ def fetch_run(workspace: Workspace, run_recovery_id: str) -> Run:
Finds an existing run in an experiment, based on a recovery ID that contains the experiment ID
and the actual RunId. The run can be specified either in the experiment_name:run_id format,
or just the run_id.
:param workspace: the configured AzureML workspace to search for the experiment.
:param run_recovery_id: The Run to find. Either in the full recovery ID format, experiment_name:run_id
or just the run_id
or just the run_id
:return: The AzureML run.
"""
return get_aml_run_from_run_id(aml_workspace=workspace, run_id=run_recovery_id)
@ -85,6 +87,7 @@ def fetch_run(workspace: Workspace, run_recovery_id: str) -> Run:
def fetch_runs(experiment: Experiment, filters: List[str]) -> List[Run]:
"""
Fetch the runs in an experiment.
:param experiment: the experiment to fetch runs from
:param filters: a list of run status to include. Must be subset of [Running, Completed, Failed, Canceled].
:return: the list of runs in the experiment
@ -107,10 +110,11 @@ def fetch_child_runs(
"""
Fetch child runs for the provided runs that have the provided AML status (or fetch all by default)
and have a run_recovery_id tag value set (this is to ignore superfluous AML infrastructure platform runs).
:param run: parent run to fetch child run from
:param status: if provided, returns only child runs with this status
:param expected_number_cross_validation_splits: when recovering child runs from AML hyperdrive
sometimes the get_children function fails to retrieve all children. If the number of child runs
sometimes the get_children function fails to retrieve all children. If the number of child runs
retrieved by AML is lower than the expected number of splits, we try to retrieve them manually.
"""
if is_ensemble_run(run):
@ -159,6 +163,7 @@ def to_azure_friendly_string(x: Optional[str]) -> Optional[str]:
def to_azure_friendly_container_path(path: Path) -> str:
"""
Converts a path an Azure friendly container path by replacing "\\", "//" with "/" so it can be in the form: a/b/c.
:param path: Original path
:return: Converted path
"""
@ -168,6 +173,7 @@ def to_azure_friendly_container_path(path: Path) -> str:
def is_offline_run_context(run_context: Run) -> bool:
"""
Tells if a run_context is offline by checking if it has an experiment associated with it.
:param run_context: Context of the run to check
:return:
"""
@ -177,6 +183,7 @@ def is_offline_run_context(run_context: Run) -> bool:
def get_run_context_or_default(run: Optional[Run] = None) -> Run:
"""
Returns the context of the run, if run is not None. If run is None, returns the context of the current run.
:param run: Run to retrieve context for. If None, retrieve ocntext of current run.
:return: Run context
"""
@ -186,6 +193,7 @@ def get_run_context_or_default(run: Optional[Run] = None) -> Run:
def get_cross_validation_split_index(run: Run) -> int:
"""
Gets the cross validation index from the run's tags or returns the default
:param run: Run context from which to get index
:return: The cross validation split index
"""
@ -204,6 +212,7 @@ def is_cross_validation_child_run(run: Run) -> bool:
"""
Checks the provided run's tags to determine if it is a cross validation child run
(which is the case if the split index >=0)
:param run: Run to check.
:return: True if cross validation run. False otherwise.
"""
@ -213,6 +222,7 @@ def is_cross_validation_child_run(run: Run) -> bool:
def strip_prefix(string: str, prefix: str) -> str:
"""
Returns the string without the prefix if it has the prefix, otherwise the string unchanged.
:param string: Input string.
:param prefix: Prefix to remove from input string.
:return: Input string with prefix removed.
@ -226,6 +236,7 @@ def get_all_environment_files(project_root: Path) -> List[Path]:
"""
Returns a list of all Conda environment files that should be used. This is firstly the InnerEye conda file,
and possibly a second environment.yml file that lives at the project root folder.
:param project_root: The root folder of the code that starts the present training run.
:return: A list with 1 or 2 entries that are conda environment files.
"""
@ -260,6 +271,7 @@ def download_run_output_file(blob_path: Path, destination: Path, run: Run) -> Pa
Downloads a single file from the run's default output directory: DEFAULT_AML_UPLOAD_DIR ("outputs").
For example, if blobs_path = "foo/bar.csv", then the run result file "outputs/foo/bar.csv" will be downloaded
to <destination>/bar.csv (the directory will be stripped off).
:param blob_path: The name of the file to download.
:param run: The AzureML run to download the files from
:param destination: Local path to save the downloaded blob to.
@ -287,6 +299,7 @@ def download_run_outputs_by_prefix(
have a given prefix (folder structure). When saving, the prefix string will be stripped off. For example,
if blobs_prefix = "foo", and the run has a file "outputs/foo/bar.csv", it will be downloaded to destination/bar.csv.
If there is in addition a file "foo.txt", that file will be skipped.
:param blobs_prefix: The prefix for all files in "outputs" that should be downloaded.
:param run: The AzureML run to download the files from.
:param destination: Local path to save the downloaded blobs to.

Просмотреть файл

@ -9,6 +9,7 @@ from typing import Any
def _is_empty(item: Any) -> bool:
"""
Returns True if the argument has length 0.
:param item: Object to check.
:return: True if the argument has length 0. False otherwise.
"""
@ -18,6 +19,7 @@ def _is_empty(item: Any) -> bool:
def _is_empty_or_empty_string_list(item: Any) -> bool:
"""
Returns True if the argument has length 0, or a list with a single element that has length 0.
:param item: Object to check.
:return: True if argument has length 0, or a list with a single element that has length 0. False otherwise.
"""
@ -32,6 +34,7 @@ def value_to_string(x: object) -> str:
"""
Returns a string representation of x, with special treatment of Enums (return their value)
and lists (return comma-separated list).
:param x: Object to convert to string
:return: The string representation of the object.
Special cases: For Enums, returns their value, for lists, returns a comma-separated list.

Просмотреть файл

@ -19,10 +19,11 @@ PYTEST_RESULTS_FILE = Path("test-results-on-azure-ml.xml")
def run_pytest(pytest_mark: str, outputs_folder: Path) -> Tuple[bool, Path]:
"""
Runs pytest on the whole test suite, restricting to the tests that have the given PyTest mark.
:param pytest_mark: The PyTest mark to use for filtering out the tests to run.
:param outputs_folder: The folder into which the test result XML file should be written.
:return: True if PyTest found tests to execute and completed successfully, False otherwise.
Also returns the path to the generated PyTest results file.
Also returns the path to the generated PyTest results file.
"""
from _pytest.main import ExitCode
_outputs_file = outputs_folder / PYTEST_RESULTS_FILE
@ -43,6 +44,7 @@ def download_pytest_result(run: Run, destination_folder: Path = Path.cwd()) -> P
"""
Downloads the pytest result file that is stored in the output folder of the given AzureML run.
If there is no pytest result file, throw an Exception.
:param run: The run from which the files should be read.
:param destination_folder: The folder into which the PyTest result file is downloaded.
:return: The path (folder and filename) of the downloaded file.

Просмотреть файл

@ -28,6 +28,7 @@ class SecretsHandling:
def __init__(self, project_root: Path) -> None:
"""
Creates a new instance of the class.
:param project_root: The root folder of the project that starts the InnerEye run.
"""
self.project_root = project_root
@ -36,8 +37,9 @@ class SecretsHandling:
"""
Reads the secrets from file in YAML format, and returns the contents as a dictionary. The YAML file is expected
in the project root directory.
:param secrets_to_read: The list of secret names to read from the YAML file. These will be converted to
uppercase.
uppercase.
:return: A dictionary with secrets, or None if the file does not exist.
"""
secrets_file = self.project_root / fixed_paths.PROJECT_SECRETS_FILE
@ -57,8 +59,9 @@ class SecretsHandling:
Attempts to read secrets from the project secret file. If there is no secrets file, it returns all secrets
in secrets_to_read read from environment variables. When reading from environment, if an expected
secret is not found, its value will be None.
:param secrets_to_read: The list of secret names to read from the YAML file. These will be converted to
uppercase.
uppercase.
"""
# Read all secrets from a local file if present, and sets the matching environment variables.
# If no secrets file is present, no environment variable is modified or created.
@ -69,9 +72,10 @@ class SecretsHandling:
def get_secret_from_environment(self, name: str, allow_missing: bool = False) -> Optional[str]:
"""
Gets a password or key from the secrets file or environment variables.
:param name: The name of the environment variable to read. It will be converted to uppercase.
:param allow_missing: If true, the function returns None if there is no entry of the given name in
any of the places searched. If false, missing entries will raise a ValueError.
any of the places searched. If false, missing entries will raise a ValueError.
:return: Value of the secret. None, if there is no value and allow_missing is True.
"""
@ -99,10 +103,11 @@ def read_all_settings(project_settings_file: Optional[Path] = None,
the `project_root` folder. Settings in the private settings file
override those in the project settings. Both settings files are expected in YAML format, with an entry called
'variables'.
:param project_settings_file: The first YAML settings file to read.
:param project_root: The folder that can contain a 'InnerEyePrivateSettings.yml' file.
:return: A dictionary mapping from string to variable value. The dictionary key is the union of variable names
found in the two settings files.
found in the two settings files.
"""
private_settings_file = None
if project_root and project_root.is_dir():
@ -117,10 +122,11 @@ def read_settings_and_merge(project_settings_file: Optional[Path] = None,
file is read into a dictionary, then the private settings file is read. Settings in the private settings file
override those in the project settings. Both settings files are expected in YAML format, with an entry called
'variables'.
:param project_settings_file: The first YAML settings file to read.
:param private_settings_file: The second YAML settings file to read. Settings in this file has higher priority.
:return: A dictionary mapping from string to variable value. The dictionary key is the union of variable names
found in the two settings files.
found in the two settings files.
"""
result = dict()
if project_settings_file:
@ -138,6 +144,7 @@ def read_settings_yaml_file(yaml_file: Path) -> Dict[str, Any]:
"""
Reads a YAML file, that is expected to contain an entry 'variables'. Returns the dictionary for the 'variables'
section of the file.
:param yaml_file: The yaml file to read.
:return: A dictionary with the variables from the yaml file.
"""

Просмотреть файл

@ -45,6 +45,7 @@ class AMLTensorBoardMonitorConfig(GenericConfig):
def monitor(monitor_config: AMLTensorBoardMonitorConfig, azure_config: AzureConfig) -> None:
"""
Starts TensorBoard monitoring as per the provided arguments.
:param monitor_config: The config containing information on which runs that need be monitored.
:param azure_config: An AzureConfig object with secrets/keys to access the workspace.
"""
@ -93,9 +94,10 @@ def main(settings_yaml_file: Optional[Path] = None,
"""
Parses the commandline arguments, and based on those, starts the Tensorboard monitoring for the AzureML runs
supplied on the commandline.
:param settings_yaml_file: The YAML file that contains all information for accessing Azure.
:param project_root: The root folder that contains all code for the present run. This is only used to locate
a private settings file InnerEyePrivateSettings.yml.
a private settings file InnerEyePrivateSettings.yml.
"""
monitor_config = AMLTensorBoardMonitorConfig.parse_args()
settings_yaml_file = settings_yaml_file or monitor_config.settings

Просмотреть файл

@ -88,9 +88,10 @@ MINIMUM_INSTANCE_COUNT = 10
def compose_distribution_comparisons(file_contents: List[List[List[str]]]) -> List[str]:
"""
Composes comparisons as detailed above.
:param file_contents: two or more lists of rows, where each "rows" is returned by read_csv_file on
(typically) a statistics.csv file
:return a list of lines to print
(typically) a statistics.csv file
:return: a list of lines to print
"""
value_lists: List[Dict[str, List[float]]] = [parse_values(rows) for rows in file_contents]
return compose_distribution_comparisons_on_lists(value_lists)
@ -121,6 +122,7 @@ def mann_whitney_on_key(key: str, lists: List[List[float]]) -> List[Tuple[Tuple[
Applies Mann-Whitney test to all sets of values (in lists) for the given key,
and return a line of results, paired with some values for ordering purposes.
Member lists with fewer than MINIMUM_INSTANCE_COUNT items are discarded.
:param key: statistic name; "Vol" statistics have mm^3 replaced by cm^3 for convenience.
:param lists: list of lists of values
"""
@ -185,7 +187,7 @@ def roc_value(lst1: List[float], lst2: List[float]) -> float:
:param lst1: a list of numbers
:param lst2: another list of numbers
:return: the proportion of pairs (x, y), where x is from lst1 and y is from lst2, for which
x < y, with x == y counting as half an instance.
x < y, with x == y counting as half an instance.
"""
if len(lst1) == 0 or len(lst2) == 0:
return 0.5
@ -256,6 +258,7 @@ def read_csv_file(input_file: str) -> List[List[str]]:
"""
Reads and returns the contents of a csv file. Empty rows (which can
result from end-of-line mismatches) are dropped.
:param input_file: path to a file in csv format
:return: list of rows from the file
"""
@ -270,7 +273,7 @@ def compare_scores_across_institutions(metrics_file: str, splits_to_use: str = "
:param splits_to_use: a comma-separated list of split names
:param mode_to_use: test, validation etc
:return: a list of comparison lines between pairs of splits. If splits_to_use is non empty,
only pairs involving at least one split from that set are compared.
only pairs involving at least one split from that set are compared.
"""
valid_splits = set(splits_to_use.split(",")) if splits_to_use else None
metrics = pd.read_csv(metrics_file)
@ -327,7 +330,7 @@ def get_arguments(arglist: List[str] = None) -> Tuple[Optional[argparse.Namespac
The value of the "-a" switch is one or more split names; pairs of splits not including these
will not be compared.
:return: parsed arguments and identifier for pattern (1, 2, 3 as above), or None, None if none of the
patterns are followed
patterns are followed
"""
# Use argparse because we want to have mandatory non-switch arguments, which GenericConfig doesn't support.
parser = argparse.ArgumentParser("Run Mann-Whitney tests")

Просмотреть файл

@ -72,6 +72,7 @@ def report_structure_extremes(dataset_dir: str, azure_config: AzureConfig) -> No
Writes structure-extreme lines for the subjects in a directory.
If there are any structures with missing slices, a ValueError is raised after writing all the lines.
This allows a build failure to be triggered when such structures exist.
:param azure_config: An object with all necessary information for accessing Azure.
:param dataset_dir: directory containing subject subdirectories with integer names.
"""
@ -129,7 +130,7 @@ def report_structure_extremes_for_subject(subj_dir: str, series_id: str) -> Iter
"""
:param subj_dir: subject directory, containing <structure>.nii.gz files
:param series_id: series identifier for the subject
Yields a line for every <structure>.nii.gz file in the directory.
Yields a line for every <structure>.nii.gz file in the directory.
"""
subject = os.path.basename(subj_dir)
series_prefix = "" if series_id is None else series_id[:8]
@ -174,7 +175,7 @@ def extent_list(presence: np.array, max_value: int) -> Tuple[List[int], List[str
:param presence: a 1-D array of distinct integers in increasing order.
:param max_value: any integer, not necessarily related to presence
:return: two tuples: (1) a list of the minimum and maximum values of presence, and max_value;
(2) a list of strings, each denoting a missing range of values within "presence".
(2) a list of strings, each denoting a missing range of values within "presence".
"""
if len(presence) == 0:
return [-1, -1, max_value], []

Просмотреть файл

@ -97,6 +97,7 @@ class WilcoxonTestConfig(GenericConfig):
def calculate_statistics(dist1: Dict[str, float], dist2: Dict[str, float], factor: float) -> Dict[str, float]:
"""
Select common pairs and run the hypothesis test.
:param dist1: mapping from keys to scores
:param dist2: mapping from keys (some or all in common with dist1) to scores
:param factor: factor to divide the Wilcoxon "z" value by to determine p-value
@ -131,10 +132,11 @@ def difference_counts(values1: List[float], values2: List[float]) -> Tuple[int,
"""
Returns the number of corresponding pairs from vals1 and vals2
in which val1 > val2 and val2 > val1 respectively.
:param values1: list of values
:param values2: list of values, same length as values1
:return: number of pairs in which first value is greater than second, and number of pairs
in which second is greater than first
in which second is greater than first
"""
n1 = 0
n2 = 0
@ -159,6 +161,7 @@ def evaluate_data_pair(data1: Dict[str, Dict[str, float]], data2: Dict[str, Dict
-> Dict[str, Dict[str, float]]:
"""
Find and compare dice scores for each structure
:param data1: dictionary from structure names, to dictionary from subjects to scores
:param data2: another such dictionary, sharing some structure names
:param is_raw_p_value: whether to use "raw" Wilcoxon z values when calculating p values (rather than reduce)
@ -240,6 +243,7 @@ def wilcoxon_signed_rank_test(args: WilcoxonTestConfig,
"""
Reads data from a csv file, and performs all pairwise comparisons, except if --against was specified,
compare every other run against the "--against" run.
:param args: parsed command line parameters
:param name_shortener: optional function to shorten names to make graphs and tables more legible
"""
@ -262,6 +266,7 @@ def run_wilcoxon_test_on_data(data: Dict[str, Dict[str, Dict[str, float]]],
"""
Performs all pairwise comparisons on the provided data, except if "against" was specified,
compare every other run against the "against" run.
:param data: scores such that data[run][structure][subject] = dice score
:param against: runs to compare against; or None to compare all against all
:param raw: whether to interpret Wilcoxon Z values "raw" or apply a correction

Просмотреть файл

@ -74,9 +74,10 @@ def get_best_epoch_results_path(mode: ModelExecutionMode,
"""
For a given model execution mode, creates the relative results path
in the form BEST_EPOCH_FOLDER_NAME/(Train, Test or Val)
:param mode: model execution mode
:param model_proc: whether this is for an ensemble or single model. If ensemble, we return a different path
to avoid colliding with the results from the single model that may have been created earlier in the same run.
to avoid colliding with the results from the single model that may have been created earlier in the same run.
"""
subpath = Path(BEST_EPOCH_FOLDER_NAME) / mode.value
if model_proc == ModelProcessing.ENSEMBLE_CREATION:
@ -108,6 +109,7 @@ def any_pairwise_larger(items1: Any, items2: Any) -> bool:
def check_is_any_of(message: str, actual: Optional[str], valid: Iterable[Optional[str]]) -> None:
"""
Raises an exception if 'actual' is not any of the given valid values.
:param message: The prefix for the error message.
:param actual: The actual value.
:param valid: The set of valid strings that 'actual' is allowed to take on.
@ -136,8 +138,9 @@ def logging_to_stdout(log_level: Union[int, str] = logging.INFO) -> None:
"""
Instructs the Python logging libraries to start writing logs to stdout up to the given logging level.
Logging will use a timestamp as the prefix, using UTC.
:param log_level: The logging level. All logging message with a level at or above this level will be written to
stdout. log_level can be numeric, or one of the pre-defined logging strings (INFO, DEBUG, ...).
stdout. log_level can be numeric, or one of the pre-defined logging strings (INFO, DEBUG, ...).
"""
log_level = standardize_log_level(log_level)
logger = logging.getLogger()
@ -186,6 +189,7 @@ def logging_to_file(file_path: Path) -> None:
Instructs the Python logging libraries to start writing logs to the given file.
Logging will use a timestamp as the prefix, using UTC. The logging level will be the same as defined for
logging to stdout.
:param file_path: The path and name of the file to write to.
"""
# This function can be called multiple times, and should only add a handler during the first call.
@ -219,6 +223,7 @@ def logging_only_to_file(file_path: Path, stdout_log_level: Union[int, str] = lo
Redirects logging to the specified file, undoing that on exit. If logging is currently going
to stdout, messages at level stdout_log_level or higher (typically ERROR) are also sent to stdout.
Usage: with logging_only_to_file(my_log_path): do_stuff()
:param file_path: file to log to
:param stdout_log_level: mininum level for messages to also go to stdout
"""
@ -253,6 +258,7 @@ def logging_section(gerund: str) -> Generator:
to help people locate particular sections. Usage:
with logging_section("doing this and that"):
do_this_and_that()
:param gerund: string expressing what happens in this section of the log.
"""
from time import time
@ -301,16 +307,19 @@ def check_properties_are_not_none(obj: Any, ignore: Optional[List[str]] = None)
def initialize_instance_variables(func: Callable) -> Callable:
"""
Automatically assigns the input parameters.
Automatically assigns the input parameters. Example usage::
class process:
@initialize_instance_variables
def __init__(self, cmd, reachable=False, user='root'):
pass
p = process('halt', True)
print(p.cmd, p.reachable, p.user)
Outputs::
('halt', True, 'root')
>>> class process:
... @initialize_instance_variables
... def __init__(self, cmd, reachable=False, user='root'):
... pass
>>> p = process('halt', True)
>>> # noinspection PyUnresolvedReferences
>>> p.cmd, p.reachable, p.user
('halt', True, 'root')
"""
names, varargs, keywords, defaults, _, _, _ = inspect.getfullargspec(func)
@ -354,6 +363,7 @@ def is_gpu_tensor(data: Any) -> bool:
def print_exception(ex: Exception, message: str, logger_fn: Callable = logging.error) -> None:
"""
Prints information about an exception, and the full traceback info.
:param ex: The exception that was caught.
:param message: An additional prefix that is printed before the exception itself.
:param logger_fn: The logging function to use for logging this exception
@ -366,6 +376,7 @@ def print_exception(ex: Exception, message: str, logger_fn: Callable = logging.e
def namespace_to_path(namespace: str, root: PathOrString = repository_root_directory()) -> Path:
"""
Given a namespace (in form A.B.C) and an optional root directory R, create a path R/A/B/C
:param namespace: Namespace to convert to path
:param root: Path to prefix (default is project root)
:return:
@ -377,6 +388,7 @@ def path_to_namespace(path: Path, root: PathOrString = repository_root_directory
"""
Given a path (in form R/A/B/C) and an optional root directory R, create a namespace A.B.C.
If root is provided, then path must be a relative child to it.
:param path: Path to convert to namespace
:param root: Path prefix to remove from namespace (default is project root)
:return:

Просмотреть файл

@ -14,6 +14,7 @@ from InnerEye.Common.type_annotations import PathOrString
def repository_root_directory(path: Optional[PathOrString] = None) -> Path:
"""
Gets the full path to the root directory that holds the present repository.
:param path: if provided, a relative path to append to the absolute path to the repository root.
:return: The full path to the repository's root directory, with symlinks resolved if any.
"""
@ -28,6 +29,7 @@ def repository_root_directory(path: Optional[PathOrString] = None) -> Path:
def repository_parent_directory(path: Optional[PathOrString] = None) -> Path:
"""
Gets the full path to the parent directory that holds the present repository.
:param path: if provided, a relative path to append to the absolute path to the repository root.
:return: The full path to the repository's root directory, with symlinks resolved if any.
"""

Просмотреть файл

@ -99,9 +99,11 @@ class GenericConfig(param.Parameterized):
def __init__(self, should_validate: bool = True, throw_if_unknown_param: bool = False, **params: Any):
"""
Instantiates the config class, ignoring parameters that are not overridable.
:param should_validate: If True, the validate() method is called directly after init.
:param throw_if_unknown_param: If True, raise an error if the provided "params" contains any key that does not
correspond to an attribute of the class.
:param params: Parameters to set.
"""
# check if illegal arguments are passed in
@ -157,6 +159,7 @@ class GenericConfig(param.Parameterized):
"""
Adds all overridable fields of the current class to the given argparser.
Fields that are marked as readonly, constant or private are ignored.
:param parser: Parser to add properties to.
"""
@ -165,6 +168,7 @@ class GenericConfig(param.Parameterized):
Parse a string as a bool. Supported values are case insensitive and one of:
'on', 't', 'true', 'y', 'yes', '1' for True
'off', 'f', 'false', 'n', 'no', '0' for False.
:param x: string to test.
:return: Bool value if string valid, otherwise a ValueError is raised.
"""
@ -179,6 +183,7 @@ class GenericConfig(param.Parameterized):
"""
Given a parameter, get its basic Python type, e.g.: param.Boolean -> bool.
Throw exception if it is not supported.
:param _p: parameter to get type and nargs for.
:return: Type
"""
@ -222,6 +227,7 @@ class GenericConfig(param.Parameterized):
Add a boolean argument.
If the parameter default is False then allow --flag (to set it True) and --flag=Bool as usual.
If the parameter default is True then allow --no-flag (to set it to False) and --flag=Bool as usual.
:param parser: parser to add a boolean argument to.
:param k: argument name.
:param p: boolean parameter.
@ -318,6 +324,7 @@ class GenericConfig(param.Parameterized):
"""
Logs a warning for every parameter whose value is not as given in "values", other than those
in keys_to_ignore.
:param values: override dictionary, parameter names to values
:param keys_to_ignore: set of dictionary keys not to report on
:return: None
@ -347,6 +354,7 @@ def create_from_matching_params(from_object: param.Parameterized, cls_: Type[T])
Creates an object of the given target class, and then copies all attributes from the `from_object` to
the newly created object, if there is a matching attribute. The target class must be a subclass of
param.Parameterized.
:param from_object: The object to read attributes from.
:param cls_: The name of the class for the newly created object.
:return: An instance of cls_

Просмотреть файл

@ -33,6 +33,7 @@ class OutputFolderForTests:
"""
Creates a full path for the given file or folder name relative to the root directory stored in the present
object.
:param file_or_folder_name: Name of file or folder to be created under root_dir
"""
return self.root_dir / file_or_folder_name
@ -40,6 +41,7 @@ class OutputFolderForTests:
def make_sub_dir(self, dir_name: str) -> Path:
"""
Makes a sub directory under root_dir
:param dir_name: Name of subdirectory to be created.
"""
sub_dir_path = self.create_file_or_folder_path(dir_name)

Просмотреть файл

@ -28,6 +28,7 @@ COL_VALUE = "value"
def memory_in_gb(bytes: int) -> float:
"""
Converts a memory amount in bytes to gigabytes.
:param bytes:
:return:
"""
@ -63,6 +64,7 @@ class GpuUtilization:
def max(self, other: GpuUtilization) -> GpuUtilization:
"""
Computes the metric-wise maximum of the two GpuUtilization objects.
:param other:
:return:
"""
@ -103,8 +105,9 @@ class GpuUtilization:
"""
Lists all metrics stored in the present object, as (metric_name, value) pairs suitable for logging in
Tensorboard.
:param prefix: If provided, this string as used as an additional prefix for the metric name itself. If prefix
is "max", the metric would look like "maxLoad_Percent"
is "max", the metric would look like "maxLoad_Percent"
:return: A list of (name, value) tuples.
"""
return [
@ -118,6 +121,7 @@ class GpuUtilization:
def from_gpu(gpu: GPU) -> GpuUtilization:
"""
Creates a GpuUtilization object from data coming from the gputil library.
:param gpu: GPU diagnostic data from gputil.
:return:
"""
@ -145,10 +149,11 @@ class ResourceMonitor(Process):
csv_results_folder: Path):
"""
Creates a process that will monitor CPU and GPU utilization.
:param interval_seconds: The interval in seconds at which usage statistics should be written.
:param tensorboard_folder: The path in which to create a tensorboard logfile.
:param csv_results_folder: The path in which the CSV file with aggregate metrics will be created.
When running in AzureML, this should NOT reside inside the /logs folder.
When running in AzureML, this should NOT reside inside the /logs folder.
"""
super().__init__(name="Resource Monitor", daemon=True)
self._interval_seconds = interval_seconds
@ -163,6 +168,7 @@ class ResourceMonitor(Process):
def log_to_tensorboard(self, label: str, value: float) -> None:
"""
Write a scalar metric value to Tensorboard, marked with the present step.
:param label: The name of the metric.
:param value: The value.
"""
@ -172,6 +178,7 @@ class ResourceMonitor(Process):
"""
Updates the stored GPU utilization metrics with the current status coming from gputil, and logs
them to Tensorboard.
:param gpus: The current utilization information, read from gputil, for all available GPUs.
"""
for gpu in gpus:

Просмотреть файл

@ -14,12 +14,13 @@ def spawn_and_monitor_subprocess(process: str,
"""
Helper function to start a subprocess, passing in a given set of arguments, and monitor it.
Returns the subprocess exit code and the list of lines written to stdout.
:param process: The name and path of the executable to spawn.
:param args: The args to the process.
:param env: The environment variables that the new process will run with. If not provided, copy the
environment from the current process.
environment from the current process.
:return: Return code after the process has finished, and the list of lines that were written to stdout by the
subprocess.
subprocess.
"""
if env is None:
env = dict(os.environ.items())

Просмотреть файл

@ -202,6 +202,7 @@ class CovidDataset(InnerEyeCXRDatasetWithReturnIndex):
def _split_dataset(self, val_split: float, seed: int) -> Tuple[Subset, Subset]:
"""
Implements val - train split.
:param val_split: proportion to use for validation
:param seed: random seed for splitting
:return: dataset_train, dataset_val

Просмотреть файл

@ -37,8 +37,10 @@ class InnerEyeVisionDataModule(VisionDataModule):
:param dataset_cls: class to load the dataset. Expected to inherit from InnerEyeDataClassBaseWithReturnIndex and
VisionDataset. See InnerEyeCIFAR10 for an example.
BEWARE VisionDataModule expects the first positional argument of your class to be the data directory.
:param return_index: whether the return the index in __get_item__, the dataset_cls is expected to implement
this logic.
this logic.
:param train_transforms: transforms to use at training time
:param val_transforms: transforms to use at validation time
:param data_dir: data directory where to find the data
@ -113,7 +115,7 @@ class CombinedDataModule(LightningDataModule):
:param encoder_module: datamodule to use for training of SSL.
:param linear_head_module: datamodule to use for training of linear head on top of frozen encoder. Can use a
different batch size than the encoder module. CombinedDataModule logic will take care of aggregation.
different batch size than the encoder module. CombinedDataModule logic will take care of aggregation.
"""
super().__init__(*args, **kwargs)
self.encoder_module = encoder_module
@ -151,6 +153,7 @@ class CombinedDataModule(LightningDataModule):
"""
Creates a CombinedLoader from the data loaders for the encoder and the linear head.
The cycle mode is chosen such that in all cases the encoder dataset is only cycled through once.
:param encoder_loader: The dataloader to use for the SSL encoder.
:param linear_head_loader: The dataloader to use for the linear head.
"""

Просмотреть файл

@ -30,12 +30,14 @@ def get_ssl_transforms_from_config(config: CfgNode,
:param config: configuration defining which augmentations to apply as well as their intensities.
:param return_two_views_per_sample: if True the resulting transforms will return two versions of each sample they
are called on. If False, simply return one transformed version of the sample centered and cropped.
are called on. If False, simply return one transformed version of the sample centered and cropped.
:param use_training_augmentations_for_validation: If True, use augmentation at validation time too.
This is required for SSL validation loss to be meaningful. If False, only apply basic processing step
(no augmentations)
This is required for SSL validation loss to be meaningful. If False, only apply basic processing step
(no augmentations)
:param expand_channels: if True the expand channel transformation from InnerEye.ML.augmentations.image_transforms
will be added to the transformation passed through the config. This is needed for single channel images as CXR.
will be added to the transformation passed through the config. This is needed for single channel images as CXR.
"""
train_transforms = create_transforms_from_config(config, apply_augmentations=True,
expand_channels=expand_channels)

Просмотреть файл

@ -63,6 +63,7 @@ def get_encoder_output_dim(
) -> int:
"""
Calculates the output dimension of ssl encoder by making a single forward pass.
:param pl_module: pl encoder module
:param dm: pl datamodule
"""

Просмотреть файл

@ -203,10 +203,10 @@ class SSLContainer(LightningContainer):
Returns torch lightning data module for encoder or linear head
:param is_ssl_encoder_module: whether to return the data module for SSL training or for linear head. If true,
:return transforms with two views per sample (batch like (img_v1, img_v2, label)). If False, return only one
view per sample but also return the index of the sample in the dataset (to make sure we don't use twice the same
batch in one training epoch (batch like (index, img_v1, label), as classifier dataloader expected to be shorter
than SSL training, hence CombinedDataloader might loop over data several times per epoch).
:return: transforms with two views per sample (batch like (img_v1, img_v2, label)). If False, return only one
view per sample but also return the index of the sample in the dataset (to make sure we don't use twice the same
batch in one training epoch (batch like (index, img_v1, label), as classifier dataloader expected to be shorter
than SSL training, hence CombinedDataloader might loop over data several times per epoch).
"""
datamodule_args = self.datamodule_args[SSLDataModuleType.ENCODER] if is_ssl_encoder_module else \
self.datamodule_args[SSLDataModuleType.LINEAR_HEAD]
@ -237,12 +237,15 @@ class SSLContainer(LightningContainer):
is_ssl_encoder_module: bool) -> Tuple[Any, Any]:
"""
Returns the transformation pipeline for training and validation.
:param augmentation_config: optional yaml config defining strength of augmenentations. Ignored for CIFAR
examples.
examples.
:param dataset_name: name of the dataset, value has to be in SSLDatasetName, determines which transformation
pipeline to return.
pipeline to return.
:param is_ssl_encoder_module: if True the transformation pipeline will yield two versions of the image it is
applied on and it applies the training transformations also at validation time. Note that if your
applied on and it applies the training transformations also at validation time. Note that if your
transformation does not contain any randomness, the pipeline will return two identical copies. If False, it
will return only one transformation.
:return: training transformation pipeline and validation transformation pipeline.

Просмотреть файл

@ -19,11 +19,11 @@ class SSLClassifierContainer(SSLContainer):
"""
This module is used to train a linear classifier on top of a frozen (or not) encoder.
If you are running on AML, you can specify the SSL training run id via the --pretraining_run_recovery_id flag. This
will automatically download the checkpoints for you and take the latest one as the starting weights of your
If you are running on AML, you can specify the SSL training run id via the ``--pretraining_run_recovery_id`` flag.
This will automatically download the checkpoints for you and take the latest one as the starting weights of your
classifier.
If you are running locally, you can specify the path to your SSL weights via the --local_ssl_weights_path parameter.
If you are running locally, you can specify the path to your SSL weights via the ``--local_ssl_weights_path`` flag.
See docs/self_supervised_models.md for more details.
"""

Просмотреть файл

@ -40,15 +40,18 @@ class BYOLInnerEye(pl.LightningModule):
**kwargs: Any) -> None:
"""
Args:
:param num_samples: Number of samples present in training dataset / dataloader.
:param learning_rate: Optimizer learning rate.
:param batch_size: Sample batch size used in gradient updates.
:param encoder_name: Type of CNN encoder used to extract image embeddings. The options are:
{'resnet18', 'resnet50', 'resnet101', 'densenet121'}.
:param warmup_epochs: Number of epochs for scheduler warm up (linear increase from 0 to base_lr).
:param use_7x7_first_conv_in_resnet: If True, use a 7x7 kernel (default) in the first layer of resnet.
If False, replace first layer by a 3x3 kernel. This is required for small CIFAR 32x32 images to not
If False, replace first layer by a 3x3 kernel. This is required for small CIFAR 32x32 images to not
shrink them.
:param weight_decay: L2-norm weight decay.
"""
super().__init__()
@ -78,10 +81,12 @@ class BYOLInnerEye(pl.LightningModule):
"""
Returns the BYOL loss for a given batch of images, used in validation
and training step.
:param batch: assumed to be a batch a Tuple(List[tensor, tensor, tensor], tensor) to match lightning-bolts
SimCLRTrainDataTransform API; the first tuple element contains a list of three tensor where the two first
SimCLRTrainDataTransform API; the first tuple element contains a list of three tensor where the two first
elements contain two are two strong augmented versions of the original images in the batch and the last
is a milder augmentation (ignored here).
:param batch_idx: index of the batch
:return: BYOL loss
"""
@ -115,7 +120,10 @@ class BYOLInnerEye(pl.LightningModule):
self.train_iters_per_epoch = self.hparams.num_samples // global_batch_size # type: ignore
def configure_optimizers(self) -> Any:
# exclude certain parameters
"""Testing this out
:return: _description_
"""
parameters = self.exclude_from_wt_decay(self.online_network.named_parameters(),
weight_decay=self.hparams.weight_decay) # type: ignore
optimizer = Adam(parameters,

Просмотреть файл

@ -41,6 +41,7 @@ class SimCLRInnerEye(SimCLR):
**kwargs: Any) -> None:
"""
Returns SimCLR pytorch-lightning module, based on lightning-bolts implementation.
:param encoder_name: Image encoder name (predefined models)
:param dataset_name: Dataset name (e.g. cifar10, kaggle, etc.)
:param use_7x7_first_conv_in_resnet: If True, use a 7x7 kernel (default) in the first layer of resnet.

Просмотреть файл

@ -38,6 +38,7 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
:param class_weights: The class weights to use when computing the cross entropy loss. If set to None,
no weighting will be done.
:param length_linear_head_loader: The maximum number of batches in the dataloader for the linear head.
"""
@ -123,6 +124,7 @@ class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
def to_device(batch: Any, device: Union[str, torch.device]) -> Tuple[T, T]:
"""
Moves batch to device.
:param device: device to move the batch to.
"""
_, x, y = batch

Просмотреть файл

@ -38,9 +38,10 @@ def load_yaml_augmentation_config(config_path: Path) -> CfgNode:
def create_ssl_encoder(encoder_name: str, use_7x7_first_conv_in_resnet: bool = True) -> torch.nn.Module:
"""
Creates SSL encoder.
:param encoder_name: available choices: resnet18, resnet50, resnet101 and densenet121.
:param use_7x7_first_conv_in_resnet: If True, use a 7x7 kernel (default) in the first layer of resnet.
If False, replace first layer by a 3x3 kernel. This is required for small CIFAR 32x32 images to not shrink them.
If False, replace first layer by a 3x3 kernel. This is required for small CIFAR 32x32 images to not shrink them.
"""
from pl_bolts.models.self_supervised.resnets import resnet18, resnet50, resnet101
from InnerEye.ML.SSL.encoders import DenseNet121Encoder

Просмотреть файл

@ -20,9 +20,9 @@ def random_select_patch_center(sample: Sample, class_weights: List[float] = None
class.
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
:return numpy int array (3x1) containing patch center spatial coordinates
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop voxel belongs
to (must sum to 1), uniform distribution assumed if none provided.
:return: numpy int array (3x1) containing patch center spatial coordinates
"""
num_classes = sample.labels.shape[0]
@ -70,9 +70,9 @@ def slicers_for_random_crop(sample: Sample,
:param sample: A set of Image channels, ground truth labels and mask to randomly crop.
:param crop_size: The size of the crop expressed as a list of 3 ints, one per spatial dimension.
:param class_weights: A weighting vector with values [0, 1] to influence the class the center crop
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
voxel belongs to (must sum to 1), uniform distribution assumed if none provided.
:return: Tuple element 1: The slicers that convert the input image to the chosen crop. Tuple element 2: The
indices of the center point of the crop.
indices of the center point of the crop.
:raises ValueError: If there are shape mismatches among the arguments or if the crop size is larger than the image.
"""
shape = sample.image.shape[1:]

Просмотреть файл

@ -23,7 +23,7 @@ class RandomGamma:
def __init__(self, scale: Tuple[float, float]) -> None:
"""
:param scale: a tuple (min_gamma, max_gamma) that specifies the range of possible values to sample the gamma
value from when the transformation is called.
value from when the transformation is called.
"""
self.scale = scale
@ -59,6 +59,7 @@ class AddGaussianNoise:
def __init__(self, p_apply: float, std: float) -> None:
"""
Transformation to add Gaussian noise N(0, std) to an image.
:param: p_apply: probability of applying the transformation.
:param: std: standard deviation of the gaussian noise to add to the image.
"""

Просмотреть файл

@ -34,10 +34,11 @@ class ImageTransformationPipeline:
use_different_transformation_per_channel: bool = False):
"""
:param transforms: List of transformations to apply to images. Supports out of the boxes torchvision transforms
as they accept data of arbitrary dimension. You can also define your own transform class but be aware that you
as they accept data of arbitrary dimension. You can also define your own transform class but be aware that you
function should expect input of shape [C, Z, H, W] and apply the same transformation to each Z slice.
:param use_different_transformation_per_channel: if True, apply a different version of the augmentation pipeline
for each channel. If False, applies the same transformation to each channel, separately.
for each channel. If False, applies the same transformation to each channel, separately.
"""
self.use_different_transformation_per_channel = use_different_transformation_per_channel
self.pipeline = Compose(transforms) if isinstance(transforms, List) else transforms
@ -93,11 +94,13 @@ def create_transforms_from_config(config: CfgNode,
Defines the image transformations pipeline from a config file. It has been designed for Chest X-Ray
images but it can be used for other types of images data, type of augmentations to use and strength are
expected to be defined in the config. The channel expansion is needed for gray images.
:param config: config yaml file fixing strength and type of augmentation to apply
:param apply_augmentations: if True return transformation pipeline with augmentations. Else,
disable augmentations i.e. only resize and center crop the image.
disable augmentations i.e. only resize and center crop the image.
:param expand_channels: if True the expand channel transformation from InnerEye.ML.augmentations.image_transforms
will be added to the transformation passed through the config. This is needed for single channel images as CXR.
will be added to the transformation passed through the config. This is needed for single channel images as CXR.
"""
transforms: List[Any] = []
if expand_channels:

Просмотреть файл

@ -111,12 +111,12 @@ def download_and_compare_scores(outputs_folder: Path, azure_config: AzureConfig,
"""
:param azure_config: Azure configuration to use for downloading data
:param comparison_blob_storage_paths: list of paths to directories containing metrics.csv and dataset.csv files,
each of the form run_recovery_id/rest_of_path
each of the form run_recovery_id/rest_of_path
:param model_dataset_df: dataframe containing contents of dataset.csv for the current model
:param model_metrics_df: dataframe containing contents of metrics.csv for the current model
:return: a dataframe for all the data (current model and all baselines); whether any comparisons were
done, i.e. whether a valid baseline was found; and the text lines to be written to the Wilcoxon results
file.
done, i.e. whether a valid baseline was found; and the text lines to be written to the Wilcoxon results
file.
"""
comparison_baselines = get_comparison_baselines(outputs_folder, azure_config, comparison_blob_storage_paths)
result = perform_score_comparisons(model_dataset_df, model_metrics_df, comparison_baselines)
@ -195,10 +195,11 @@ def compare_files(expected: Path, actual: Path, csv_relative_tolerance: float =
basis.
:param expected: A file that contains the expected contents. The type of comparison (text or binary) is chosen
based on the extension of this file.
based on the extension of this file.
:param actual: A file that contains the actual contents.
:param csv_relative_tolerance: When comparing CSV files, use this as the maximum allowed relative discrepancy.
If 0.0, do not allow any discrepancy.
If 0.0, do not allow any discrepancy.
:return: An empty string if the files appear identical, or otherwise an error message with details.
"""
@ -257,9 +258,9 @@ def compare_folder_contents(expected_folder: Path,
:param actual_folder: The output folder with the actually produced files.
:param run: An AzureML run
:param csv_relative_tolerance: When comparing CSV files, use this as the maximum allowed relative discrepancy.
If 0.0, do not allow any discrepancy.
If 0.0, do not allow any discrepancy.
:return: A list of human readable error messages, with message and file path. If no errors are found, the list is
empty.
empty.
"""
messages = []
if run and is_offline_run_context(run):
@ -302,7 +303,7 @@ def compare_folders_and_run_outputs(expected: Path, actual: Path, csv_relative_t
:param expected: A folder with files that are expected to be present.
:param actual: The output folder with the actually produced files.
:param csv_relative_tolerance: When comparing CSV files, use this as the maximum allowed relative discrepancy.
If 0.0, do not allow any discrepancy.
If 0.0, do not allow any discrepancy.
"""
if not expected.is_dir():
raise ValueError(f"Folder with expected files does not exist: {expected}")

Просмотреть файл

@ -86,6 +86,7 @@ def create_unique_timestamp_id() -> str:
def get_best_checkpoint_path(path: Path) -> Path:
"""
Given a path and checkpoint, formats a path based on the checkpoint file name format.
:param path to checkpoint folder
"""
return path / LAST_CHECKPOINT_FILE_NAME_WITH_SUFFIX

Просмотреть файл

@ -47,12 +47,12 @@ class CovidModel(ScalarModelBase):
"""
Model to train a CovidDataset model from scratch or finetune from SSL-pretrained model.
For AML you need to provide the run_id of your SSL training job as a command line argument
--pretraining_run_recovery_id=id_of_your_ssl_model, this will download the checkpoints of the run to your
For AML you need to provide the run_id of your SSL training job as a command line argument:
``--pretraining_run_recovery_id=<id_of_your_ssl_model>``. This will download the checkpoints of the run to your
machine and load the corresponding pretrained model.
To recover from a particular checkpoint from your SSL run e.g. "recovery_epoch=499.ckpt" please use the
--name_of_checkpoint argument.
To recover from a particular checkpoint from your SSL run e.g. ``"recovery_epoch=499.ckpt"`` please use the
``--name_of_checkpoint`` argument.
"""
use_pretrained_model = param.Boolean(default=False, doc="If True, start training from a model pretrained with SSL."
"If False, start training a DenseNet model from scratch"
@ -242,9 +242,11 @@ class CovidModel(ScalarModelBase):
Generate a custom report for the Covid model. This report will read the file model_output.csv generated for
the training, validation or test sets and compute both the multiclass accuracy and the accuracy for each of the
hierarchical tasks.
:param report_dir: Directory report is to be written to
:param model_proc: Whether this is a single or ensemble model (model_output.csv will be located in different
paths for single vs ensemble runs.)
paths for single vs ensemble runs.)
"""
label_prefix = LoggingColumns.Label.value
@ -364,6 +366,11 @@ class CovidModel(ScalarModelBase):
class DicomPreparation:
def __call__(self, item: torch.Tensor) -> PIL.Image:
"""Call class as a function. This will act as a transformation function for the dataset.
:param item: tensor to transform.
:return: transformed data.
"""
# Item will be of dimension [C, Z, X, Y]
images = item.numpy()
assert images.shape[0] == 1 and images.shape[1] == 1

Просмотреть файл

@ -35,8 +35,9 @@ class HelloDataset(Dataset):
def __init__(self, raw_data: List[List[float]]) -> None:
"""
Creates the 1-dim regression dataset.
:param raw_data: The raw data, e.g. from a cross validation split or loaded from file. This
must be numeric data which can be converted into a tensor. See the static method
must be numeric data which can be converted into a tensor. See the static method
from_path_and_indexes for an example call.
"""
super().__init__()
@ -55,6 +56,7 @@ class HelloDataset(Dataset):
end_index: int) -> 'HelloDataset':
'''
Static method to instantiate a HelloDataset from the root folder with the start and end indexes.
:param root_folder: The folder in which the data file lives ("hellocontainer.csv")
:param start_index: The first row to read.
:param end_index: The last row to read (exclusive)
@ -137,6 +139,7 @@ class HelloRegression(LightningModule):
This method is part of the standard PyTorch Lightning interface. For an introduction, please see
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
It runs a forward pass of a tensor through the model.
:param x: The input tensor(s)
:return: The model output.
"""
@ -148,6 +151,7 @@ class HelloRegression(LightningModule):
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
It consumes a minibatch of training data (coming out of the data loader), does forward propagation, and
computes the loss.
:param batch: The batch of training data
:return: The loss value with a computation graph attached.
"""
@ -162,6 +166,7 @@ class HelloRegression(LightningModule):
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
It consumes a minibatch of validation data (coming out of the data loader), does forward propagation, and
computes the loss.
:param batch: The batch of validation data
:return: The loss value on the validation data.
"""
@ -173,6 +178,7 @@ class HelloRegression(LightningModule):
"""
This is a convenience method to reduce code duplication, because training, validation, and test step share
large amounts of code.
:param batch: The batch of data to process, with input data and targets.
:return: The MSE loss that the model achieved on this batch.
"""
@ -207,6 +213,7 @@ class HelloRegression(LightningModule):
https://pytorch-lightning.readthedocs.io/en/stable/starter/converting.html
It evaluates the model in "inference mode" on data coming from the test set. It could, for example,
also write each model prediction to disk.
:param batch: The batch of test data.
:param batch_idx: The index (0, 1, ...) of the batch when the data loader is enumerated.
:return: The loss on the test data.

Просмотреть файл

@ -47,11 +47,12 @@ def get_fastmri_data_module(azure_dataset_id: str,
Creates a LightningDataModule that consumes data from the FastMRI challenge. The type of challenge
(single/multicoil) is determined from the name of the dataset in Azure blob storage. The mask type is set to
equispaced, with 4x acceleration.
:param azure_dataset_id: The name of the dataset (folder name in blob storage).
:param local_dataset: The local folder at which the dataset has been mounted or downloaded.
:param sample_rate: Fraction of slices of the training data split to use. Set to a value <1.0 for rapid prototyping.
:param test_path: The name of the folder inside the dataset that contains the test data.
:return: A LightningDataModule object.
:return: The FastMRI LightningDataModule object.
"""
if not azure_dataset_id:
raise ValueError("The azure_dataset_id argument must be provided.")

Просмотреть файл

@ -47,17 +47,23 @@ class HeadAndNeckBase(SegmentationModelBase):
**kwargs: Any) -> None:
"""
Creates a new instance of the class.
:param ground_truth_ids: List of ground truth ids.
:param ground_truth_ids_display_names: Optional list of ground truth id display names. If
present then must be of the same length as ground_truth_ids.
present then must be of the same length as ground_truth_ids.
:param colours: Optional list of colours. If
present then must be of the same length as ground_truth_ids.
present then must be of the same length as ground_truth_ids.
:param fill_holes: Optional list of fill hole flags. If
present then must be of the same length as ground_truth_ids.
present then must be of the same length as ground_truth_ids.
:param roi_interpreted_types: Optional list of roi_interpreted_types. If
present then must be of the same length as ground_truth_ids.
present then must be of the same length as ground_truth_ids.
:param class_weights: Optional list of class weights. If
present then must be of the same length as ground_truth_ids + 1.
present then must be of the same length as ground_truth_ids + 1.
:param slice_exclusion_rules: Optional list of SliceExclusionRules.
:param summed_probability_rules: Optional list of SummedProbabilityRule.
:param num_feature_channels: Optional number of feature channels.

Просмотреть файл

@ -29,6 +29,7 @@ class HeadAndNeckPaper(HeadAndNeckBase):
def __init__(self, num_structures: Optional[int] = None, **kwargs: Any) -> None:
"""
Creates a new instance of the class.
:param num_structures: number of structures from STRUCTURE_LIST to predict (default: all structures)
:param kwargs: Additional arguments that will be passed through to the SegmentationModelBase constructor.
"""

Просмотреть файл

@ -94,6 +94,7 @@ class HelloWorld(SegmentationModelBase):
https://docs.microsoft.com/en-us/azure/machine-learning/service/how-to-tune-hyperparameters
A reference is provided at https://docs.microsoft.com/en-us/python/api/azureml-train-core/azureml.train
.hyperdrive?view=azure-ml-py
:param run_config: The configuration for running an individual experiment.
:return: An Azure HyperDrive run configuration (configured PyTorch environment).
"""

Просмотреть файл

@ -28,17 +28,23 @@ class ProstateBase(SegmentationModelBase):
**kwargs: Any) -> None:
"""
Creates a new instance of the class.
:param ground_truth_ids: List of ground truth ids.
:param ground_truth_ids_display_names: Optional list of ground truth id display names. If
present then must be of the same length as ground_truth_ids.
present then must be of the same length as ground_truth_ids.
:param colours: Optional list of colours. If
present then must be of the same length as ground_truth_ids.
present then must be of the same length as ground_truth_ids.
:param fill_holes: Optional list of fill hole flags. If
present then must be of the same length as ground_truth_ids.
present then must be of the same length as ground_truth_ids.
:param interpreted_types: Optional list of interpreted_types. If
present then must be of the same length as ground_truth_ids.
present then must be of the same length as ground_truth_ids.
:param class_weights: Optional list of class weights. If
present then must be of the same length as ground_truth_ids + 1.
present then must be of the same length as ground_truth_ids + 1.
:param kwargs: Additional arguments that will be passed through to the SegmentationModelBase constructor.
"""
ground_truth_ids_display_names = ground_truth_ids_display_names or [f"zz_{name}" for name in ground_truth_ids]

Просмотреть файл

@ -20,6 +20,7 @@ class ProstatePaper(ProstateBase):
def __init__(self, **kwargs: Any) -> None:
"""
Creates a new instance of the class.
:param kwargs: Additional arguments that will be passed through to the SegmentationModelBase constructor.
"""
ground_truth_ids = fg_classes

Просмотреть файл

@ -91,7 +91,7 @@ class PyTorchPassthroughModel(BaseSegmentationModel):
Ignore the actual patches and return a fixed segmentation, explained in make_nesting_rectangles.
:param patches: Set of patches, of shape (#patches, #image_channels, Z, Y, X). Only the shape
is used.
is used.
:return: Fixed tensor of shape (#patches, number_of_classes, Z, Y, Z).
"""
output_size: TupleInt3 = (patches.shape[2], patches.shape[3], patches.shape[4])

Просмотреть файл

@ -57,6 +57,7 @@ class CroppingDataset(FullImageDataset):
"""
Pad the original sample such the the provided images has the same
(or slightly larger in case of uneven difference) shape to the output_size, using the provided padding mode.
:param sample: Sample to pad.
:param crop_size: Crop size to match.
:param padding_mode: The padding scheme to apply.
@ -89,10 +90,12 @@ class CroppingDataset(FullImageDataset):
class_weights: Optional[List[float]] = None) -> CroppedSample:
"""
Creates an instance of a cropped sample extracted from full 3D images.
:param sample: the full size 3D sample to use for extracting a cropped sample.
:param crop_size: the size of the crop to extract.
:param center_size: the size of the center of the crop (this should be the same as the spatial dimensions
of the posteriors that the model produces)
:param class_weights: the distribution to use for the crop center class.
:return: CroppedSample
"""

Просмотреть файл

@ -31,6 +31,7 @@ def collate_with_metadata(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
The collate function that the dataloader workers should use. It does the same thing for all "normal" fields
(all fields are put into tensors with outer dimension batch_size), except for the special "metadata" field.
Those metadata objects are collated into a simple list.
:param batch: A list of samples that should be collated.
:return: collated result
"""
@ -84,7 +85,7 @@ class ImbalancedSampler(Sampler):
"""
:param dataset: a dataset
:num_samples: number of samples to draw. If None the number of samples
:num_samples: number of samples to draw. If None the number of samples
corresponds to the length of the dataset.
"""
self.dataset = dataset
@ -123,6 +124,7 @@ class RepeatDataLoader(DataLoader):
**kwargs: Any):
"""
Creates a new data loader.
:param dataset: The dataset that should be loaded.
:param batch_size: The number of samples per minibatch.
:param shuffle: If true, the dataset will be shuffled randomly.
@ -204,11 +206,12 @@ class FullImageDataset(GeneralDataset):
"""
Dataset class that loads and creates samples with full 3D images from a given pd.Dataframe. The following
are the operations performed to generate a sample from this dataset:
-------------------------------------------------------------------------------------------------
1) On initialization parses the provided pd.Dataframe with dataset information, to cache the set of file paths
and patient mappings to load as PatientDatasetSource. The sources are then saved in a list: dataset_sources.
2) dataset_sources is iterated in a batched fashion, where for each batch it loads the full 3D images, and applies
pre-processing functions (e.g. normalization), returning a sample that can be used for full image operations.
"""
def __init__(self, args: SegmentationModelBase, data_frame: pd.DataFrame,
@ -313,6 +316,7 @@ def load_dataset_sources(dataframe: pd.DataFrame,
The dataframe contains per-patient per-channel image information, relative to a root directory.
This method converts that into a per-patient dictionary, that contains absolute file paths
separated for for image channels, ground truth channels, and mask channels.
:param dataframe: A dataframe read directly from a dataset CSV file.
:param local_dataset_root_folder: The root folder that contains all images.
:param image_channels: The names of the image channels that should be used in the result.

Просмотреть файл

@ -40,6 +40,7 @@ class PatientMetadata:
For each of the columns "seriesId", "instituionId" and "tags", the distinct values for the given patient are
computed. If there is exactly 1 distinct value, that is returned as the respective patient metadata. If there is
more than 1 distinct value, the metadata column is set to None.
:param dataframe: The dataset to read from.
:param patient_id: The ID of the patient for which the metadata should be extracted.
:return: An instance of PatientMetadata for the given patient_id
@ -101,8 +102,9 @@ class SampleBase:
def from_dict(cls: Type[T], sample: Dict[str, Any]) -> T:
"""
Create an instance of the sample class, based on the provided sample dictionary
:param sample: dictionary of arguments
:return:
:return: an instance of the SampleBase class
"""
return cls(**sample) # type: ignore
@ -110,6 +112,7 @@ class SampleBase:
"""
Create a clone of the current sample, with the provided overrides to replace the
existing properties if they exist.
:param overrides:
:return:
"""
@ -118,7 +121,6 @@ class SampleBase:
def get_dict(self) -> Dict[str, Any]:
"""
Get the current sample as a dictionary of property names and their values.
:return:
"""
return vars(self)

Просмотреть файл

@ -33,7 +33,7 @@ def extract_label_classification(label_string: str, sample_id: str, num_classes:
Converts a string from a dataset.csv file that contains a model's label to a scalar.
For classification datasets:
If num_classes is 1 (binary classification tasks):
If num_classes is 1 (binary classification tasks)
The function maps ["1", "true", "yes"] to [1], ["0", "false", "no"] to [0].
If the entry in the CSV file was missing (no string given at all) or an empty string, it returns math.nan.
If num_classes is greater than 1 (multilabel datasets):
@ -42,17 +42,17 @@ def extract_label_classification(label_string: str, sample_id: str, num_classes:
map "1|3|4" to [0, 1, 0, 1, 1, 0]).
If the entry in the CSV file was missing (no string given at all) or an empty string,
this function returns an all-zero tensor (none of the label classes were positive for this sample).
For regression datasets:
The function casts a string label to float. Raises an exception if the conversion is
not possible.
If the entry in the CSV file was missing (no string given at all) or an empty string, it returns math.nan.
The function casts a string label to float. Raises an exception if the conversion is
not possible.
If the entry in the CSV file was missing (no string given at all) or an empty string, it returns math.nan.
:param label_string: The value of the label as read from CSV via a DataFrame.
:param sample_id: The sample ID where this label was read from. This is only used for creating error messages.
:param num_classes: Number of classes. This should be equal the size of the model output.
For binary classification tasks, num_classes should be one. For multilabel classification tasks, num_classes should
correspond to the number of label classes in the problem.
For binary classification tasks, num_classes should be one. For multilabel classification tasks, num_classes
should correspond to the number of label classes in the problem.
:param is_classification_dataset: If the model is a classification model
:return: A list of floats with the same size as num_classes
"""
@ -120,9 +120,11 @@ def _get_single_channel_row(subject_rows: pd.DataFrame,
'channel' argument. Throws a ValueError if there is no or more than 1 such row.
The result is returned as a dictionary, not a DataFrame!
If the 'channel' argument is null, the input is expected to be already 1 row, which is returned as a dictionary.
:param subject_rows: A set of rows all belonging to the same subject.
:param channel: The value to look for in the `channel_column` column. This can be null. If it is null,
the input `subject_rows` is expected to have exactly 1 row.
the input `subject_rows` is expected to have exactly 1 row.
:param subject_id: A string describing the presently processed subject. This is only used for error reporting.
:return: A dictionary mapping from column names to values, created from the unique row that was found.
"""
@ -144,6 +146,7 @@ def _string_to_float(text: Union[str, float], error_message_prefix: str = None)
"""
Converts a string coming from a dataset.csv file to a floating point number, taking into account all the
corner cases that can happen when the dataset file is malformed.
:param text: The element coming from the dataset.csv file.
:param error_message_prefix: A prefix string that will go into the error message if the conversion fails.
:return: A floating point number, possibly np.nan.
@ -181,27 +184,32 @@ def load_single_data_source(subject_rows: pd.DataFrame,
"""
Converts a set of dataset rows for a single subject to a ScalarDataSource instance, which contains the
labels, the non-image features, and the paths to the image files.
:param num_classes: Number of classes, this is equivalent to model output tensor size
:param channel_column: The name of the column that contains the row identifier ("channels")
:param metadata_columns: A list of columns that well be added to the item metadata as key/value pairs.
:param subject_rows: All dataset rows that belong to the same subject.
:param subject_id: The identifier of the subject that is being processed.
:param image_channels: The names of all channels (stored in the CSV_CHANNEL_HEADER column of the dataframe)
that are expected to be loaded from disk later because they are large images.
that are expected to be loaded from disk later because they are large images.
:param image_file_column: The name of the column that contains the image file names.
:param label_channels: The name of the channel where the label scalar or vector is read from.
:param label_value_column: The column that contains the value for the label scalar or vector.
:param non_image_feature_channels: non_image_feature_channels: A dictonary of the names of all channels where
additional scalar values should be read from. THe keys should map each feature to its channels.
additional scalar values should be read from. THe keys should map each feature to its channels.
:param numerical_columns: The names of all columns where additional scalar values should be read from.
:param categorical_data_encoder: Encoding scheme for categorical data.
:param is_classification_dataset: If True, the dataset will be used in a classification model. If False,
assume that the dataset will be used in a regression model.
assume that the dataset will be used in a regression model.
:param transform_labels: a label transformation or a list of label transformation to apply to the labels.
If a list is provided, the transformations are applied in order from left to right.
If a list is provided, the transformations are applied in order from left to right.
:param sequence_position_numeric: Numeric position of the data source in a data sequence. Assumed to be
a non-sequential dataset item if None provided (default).
:return:
a non-sequential dataset item if None provided (default).
:return: A ScalarDataSource containing the specified data.
"""
def _get_row_for_channel(channel: Optional[str]) -> Dict[str, str]:
@ -234,6 +242,7 @@ def load_single_data_source(subject_rows: pd.DataFrame,
"""
Return either the list of channels for a given column or if None was passed as
numerical channels i.e. there are no channel to be specified return [None].
:param non_image_channels: Dict mapping features name to their channels
:param feature: feature name for which to return the channels
:return: List of channels for the given feature.
@ -351,15 +360,19 @@ class DataSourceReader():
:param image_channels: The names of all channels (stored in the CSV_CHANNEL_HEADER column of the dataframe)
:param label_channels: The name of the channel where the label scalar or vector is read from.
:param transform_labels: a label transformation or a list of label transformation to apply to the labels.
If a list is provided, the transformations are applied in order from left to right.
If a list is provided, the transformations are applied in order from left to right.
:param non_image_feature_channels: non_image_feature_channels: A dictionary of the names of all channels where
additional scalar values should be read from. The keys should map each feature to its channels.
additional scalar values should be read from. The keys should map each feature to its channels.
:param numerical_columns: The names of all columns where additional scalar values should be read from.
:param sequence_column: The name of the column that contains the sequence index, that will be stored in
metadata.sequence_position. If this column name is not provided, the sequence_position will be 0.
metadata.sequence_position. If this column name is not provided, the sequence_position will be 0.
:param subject_column: The name of the column that contains the subject identifier
:param channel_column: The name of the column that contains the row identifier ("channels")
that are expected to be loaded from disk later because they are large images.
that are expected to be loaded from disk later because they are large images.
:param is_classification_dataset: If the current dataset is classification or not.
:param categorical_data_encoder: Encoding scheme for categorical data.
"""
@ -422,6 +435,7 @@ class DataSourceReader():
"""
Loads dataset items from the given dataframe, where all column and channel configurations are taken from their
respective model config elements.
:param data_frame: The dataframe to read dataset items from.
:param args: The model configuration object.
:return: A list of all dataset items that could be read from the dataframe.
@ -452,13 +466,13 @@ class DataSourceReader():
def load_data_sources(self, num_dataset_reader_workers: int = 0) -> List[ScalarDataSource]:
"""
Extracts information from a dataframe to create a list of ClassificationItem. This will create one entry per
unique
value of subject_id in the dataframe. The file is structured around "channels", indicated by specific values in
the CSV_CHANNEL_HEADER column. The result contains paths to image files, a label vector, and a matrix of
additional values that are specified by rows and columns given in non_image_feature_channels and
unique value of subject_id in the dataframe. The file is structured around "channels", indicated by specific
values in the CSV_CHANNEL_HEADER column. The result contains paths to image files, a label vector, and a matrix
of additional values that are specified by rows and columns given in non_image_feature_channels and
numerical_columns.
:param num_dataset_reader_workers: Number of worker processes to use, if 0 then single threaded execution,
otherwise if -1 then multiprocessing with all available cpus will be used.
otherwise if -1 then multiprocessing with all available cpus will be used.
:return: A list of ScalarDataSource or SequenceDataSource instances
"""
subject_ids = self.data_frame[self.subject_column].unique()
@ -512,9 +526,9 @@ def files_by_stem(root_path: Path) -> Dict[str, Path]:
"""
Lists all files under the given root directory recursively, and returns a mapping from file name stem to full path.
The file name stem is computed more restrictively than what Path.stem returns: file.nii.gz will use "file" as the
stem, not "file.nii" as Path.stem would.
Only actual files are returned in the mapping, no directories.
If there are multiple files that map to the same stem, the function raises a ValueError.
stem, not "file.nii" as Path.stem would. Only actual files are returned in the mapping, no directories. If there are
multiple files that map to the same stem, the function raises a ValueError.
:param root_path: The root directory from which the file search should start.
:return: A dictionary mapping from file name stem to the full path to where the file is found.
"""
@ -546,11 +560,13 @@ def is_valid_item_index(item: ScalarDataSource,
min_sequence_position_value: int = 0) -> bool:
"""
Returns True if the item metadata in metadata.sequence_position is a valid sequence index.
:param item: The item to check.
:param min_sequence_position_value: Check if the item has a metadata.sequence_position that is at least
the value given here. Default is 0.
the value given here. Default is 0.
:param max_sequence_position_value: If provided then this is the maximum sequence position the sequence can
end with. Longer sequences will be truncated. None is default.
end with. Longer sequences will be truncated. None is default.
:return: True if the item has a valid index.
"""
# If no max_sequence_position_value is given, we don't care about
@ -572,9 +588,11 @@ def filter_valid_classification_data_sources_items(items: Iterable[ScalarDataSou
:param items: The list of items to filter.
:param min_sequence_position_value: Restrict the data to items with a metadata.sequence_position that is at least
the value given here. Default is 0.
the value given here. Default is 0.
:param max_sequence_position_value: If provided then this is the maximum sequence position the sequence can
end with. Longer sequences will be truncated. None is default.
end with. Longer sequences will be truncated. None is default.
:param file_to_path_mapping: A mapping from a file name stem (without extension) to its full path.
:return: A list of items, all of which are valid now.
"""
@ -637,9 +655,10 @@ class ScalarItemAugmentation:
segmentation_transform: Optional[Callable] = None) -> None:
"""
:param image_transform: transformation function to apply to images field. If None, images field is unchanged by
call.
call.
:param segmentation_transform: transformation function to apply to segmentations field. If None segmentations
field is unchanged by call.
field is unchanged by call.
"""
self.image_transform = image_transform
self.segmentation_transform = segmentation_transform
@ -671,7 +690,8 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], ScalarDataSource):
name: Optional[str] = None,
sample_transform: Callable[[ScalarItem], ScalarItem] = ScalarItemAugmentation()):
"""
High level class for the scalar dataset designed to be inherited for specific behaviour
High level class for the scalar dataset designed to be inherited for specific behaviour.
:param args: The model configuration object.
:param data_frame: The dataframe to read from.
:param feature_statistics: If given, the normalization factor for the non-image features is taken
@ -691,7 +711,8 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], ScalarDataSource):
def load_all_data_sources(self) -> List[ScalarDataSource]:
"""
Uses the dataframe to create data sources to be used by the dataset.
:return:
:return: List of data sources.
"""
all_data_sources = DataSourceReader.load_data_sources_as_per_config(self.data_frame, self.args) # type: ignore
self.status += f"Loading: {self.create_status_string(all_data_sources)}"
@ -722,9 +743,10 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], ScalarDataSource):
"""
Loads the images and/or segmentations as given in the ClassificationDataSource item and
applying the optional transformation specified by the class.
:param item: The item to load.
:return: A ClassificationItem instances with the loaded images, and the labels and non-image features copied
from the argument.
from the argument.
"""
sample = item.load_images(
root_path=self.args.local_dataset,
@ -738,6 +760,7 @@ class ScalarDatasetBase(GeneralDataset[ScalarModelBase], ScalarDataSource):
def create_status_string(self, items: List[ScalarDataSource]) -> str:
"""
Creates a human readable string that contains the number of items, and the distinct number of subjects.
:param items: Use the items provided to create the string
:return: A string like "12 items for 5 subjects"
"""
@ -757,10 +780,12 @@ class ScalarDataset(ScalarDatasetBase):
sample_transform: Callable[[ScalarItem], ScalarItem] = ScalarItemAugmentation()):
"""
Creates a new scalar dataset from a dataframe.
:param args: The model configuration object.
:param data_frame: The dataframe to read from.
:param feature_statistics: If given, the normalization factor for the non-image features is taken
from the values provided. If None, the normalization factor is computed from the data in the present dataset.
from the values provided. If None, the normalization factor is computed from the data in the present dataset.
:param sample_transform: Sample transforms that should be applied.
:param name: Name of the dataset, used for diagnostics logging
"""
@ -802,6 +827,7 @@ class ScalarDataset(ScalarDatasetBase):
one class index. The value stored will be the number of samples that belong to the positive class.
In the multilabel case, this returns a dictionary with class indices and samples per class as the key-value
pairs.
:return: Dictionary of {class_index: count}
"""
all_labels = [torch.flatten(torch.nonzero(item.label).int()).tolist() for item in self.items] # [N, 1]

Просмотреть файл

@ -39,7 +39,6 @@ class ScalarItemBase(SampleBase):
def id(self) -> str:
"""
Gets the identifier of the present object from metadata.
:return:
"""
return self.metadata.id # type: ignore
@ -47,7 +46,6 @@ class ScalarItemBase(SampleBase):
def props(self) -> Dict[str, Any]:
"""
Gets the general metadata dictionary for the present object.
:return:
"""
return self.metadata.props # type: ignore
@ -94,6 +92,7 @@ class ScalarItem(ScalarItemBase):
"""
Creates a copy of the present object where all tensors live on the given CUDA device.
The metadata field is left unchanged.
:param device: The CUDA or GPU device to move to.
:return: A new `ScalarItem` with all tensors on the chosen device.
"""
@ -124,16 +123,19 @@ class ScalarDataSource(ScalarItemBase):
root_path argument, or it must contain a file name stem only (without extension). In this case, the actual
mapping from file name stem to full path is expected in the file_mapping argument.
Either of 'root_path' or 'file_mapping' must be provided.
:param root_path: The root path where all channel files for images are expected. This is ignored if
file_mapping is given.
file_mapping is given.
:param file_mapping: A mapping from a file name stem (without extension) to its full path.
:param load_segmentation: If True it loads segmentation if present on the same file as the image.
:param center_crop_size: If supplied, all loaded images will be cropped to the size given here. The crop will
be taken from the center of the image.
be taken from the center of the image.
:param image_size: If given, all loaded images will be reshaped to the size given here, prior to the
center crop.
center crop.
:return: An instance of ClassificationItem, with the same label and numerical_non_image_features fields,
and all images loaded.
and all images loaded.
"""
full_channel_files = self.get_all_image_filepaths(root_path=root_path,
file_mapping=file_mapping)
@ -156,8 +158,9 @@ class ScalarDataSource(ScalarItemBase):
"""
Checks if all file paths and non-image features are present in the object. All image channel files must
be not None, and none of the non imaging features may be NaN or infinity.
:return: True if channel files is a list with not-None entries, and all non imaging features are finite
floating point numbers.
floating point numbers.
"""
return self.files_valid() and super().is_valid()
@ -169,8 +172,10 @@ class ScalarDataSource(ScalarItemBase):
file_mapping: Optional[Dict[str, Path]]) -> List[Path]:
"""
Get a list of image paths for the object. Either root_path or file_mapping must be specified.
:param root_path: The root path where all channel files for images are expected. This is ignored if
file_mapping is given.
file_mapping is given.
:param file_mapping: A mapping from a file name stem (without extension) to its full path.
"""
full_channel_files: List[Path] = []
@ -188,9 +193,11 @@ class ScalarDataSource(ScalarItemBase):
"""
Get the full path of an image file given the path relative to the dataset folder and one of
root_path or file_mapping.
:param file: Image filepath relative to the dataset folder
:param root_path: The root path where all channel files for images are expected. This is ignored if
file_mapping is given.
file_mapping is given.
:param file_mapping: A mapping from a file name stem (without extension) to its full path.
"""
if file is None:

Просмотреть файл

@ -117,14 +117,15 @@ class DeepLearningFileSystemConfig(Parameterized):
Creates a new object that holds output folder configurations. When running inside of AzureML, the output
folders will be directly under the project root. If not running inside AzureML, a folder with a timestamp
will be created for all outputs and logs.
:param project_root: The root folder that contains the code that submitted the present training run.
When running inside the InnerEye repository, it is the git repo root. When consuming InnerEye as a package,
this should be the root of the source code that calls the package.
When running inside the InnerEye repository, it is the git repo root. When consuming InnerEye as a package,
this should be the root of the source code that calls the package.
:param is_offline_run: If true, this is a run outside AzureML. If False, it is inside AzureML.
:param model_name: The name of the model that is trained. This is used to generate a run-specific output
folder.
folder.
:param output_to: If provided, the output folders will be created as a subfolder of this argument. If not
given, the output folders will be created inside of the project root.
given, the output folders will be created inside of the project root.
"""
if not project_root.is_absolute():
raise ValueError(f"The project root is required to be an absolute path, but got {project_root}")
@ -165,6 +166,7 @@ class DeepLearningFileSystemConfig(Parameterized):
"""
Creates a new output folder configuration, where both outputs and logs go into the given subfolder inside
the present outputs folder.
:param subfolder: The subfolder that should be created.
:return:
"""
@ -281,6 +283,7 @@ class WorkflowParams(param.Parameterized):
data_split: ModelExecutionMode) -> bool:
"""
Returns True if inference is required for this model_proc (single or ensemble) and data_split (Train/Val/Test).
:param model_proc: Whether we are testing an ensemble or single model.
:param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
:return: True if inference required.
@ -442,6 +445,7 @@ class OutputParams(param.Parameterized):
def set_output_to(self, output_to: PathOrString) -> None:
"""
Adjusts the file system settings in the present object such that all outputs are written to the given folder.
:param output_to: The absolute path to a folder that should contain the outputs.
"""
if isinstance(output_to, Path):
@ -453,6 +457,7 @@ class OutputParams(param.Parameterized):
"""
Creates new file system settings (outputs folder, logs folder) based on the information stored in the
present object. If any of the folders do not yet exist, they are created.
:param project_root: The root folder for the codebase that triggers the training run.
"""
self.file_system_config = DeepLearningFileSystemConfig.create(
@ -783,6 +788,7 @@ class DeepLearningConfig(WorkflowParams,
def dataset_data_frame(self, data_frame: Optional[DataFrame]) -> None:
"""
Sets the pandas data frame that the model uses.
:param data_frame: The data frame to set.
"""
self._dataset_data_frame = data_frame
@ -844,12 +850,13 @@ class DeepLearningConfig(WorkflowParams,
See https://pytorch.org/tutorials/beginner/saving_loading_models.html#warmstarting-model-using-parameters
-from-a-different-model
for an explanation on why strict=False is useful when loading parameters from other models.
:param path_to_checkpoint: Path to the checkpoint file.
:return: Dictionary with model and optimizer state dicts. The dict should have at least the following keys:
1. Key ModelAndInfo.MODEL_STATE_DICT_KEY and value set to the model state dict.
2. Key ModelAndInfo.EPOCH_KEY and value set to the checkpoint epoch.
Other (optional) entries corresponding to keys ModelAndInfo.OPTIMIZER_STATE_DICT_KEY and
ModelAndInfo.MEAN_TEACHER_STATE_DICT_KEY are also supported.
1. Key ModelAndInfo.MODEL_STATE_DICT_KEY and value set to the model state dict.
2. Key ModelAndInfo.EPOCH_KEY and value set to the checkpoint epoch.
Other (optional) entries corresponding to keys ModelAndInfo.OPTIMIZER_STATE_DICT_KEY and
ModelAndInfo.MEAN_TEACHER_STATE_DICT_KEY are also supported.
"""
return load_checkpoint(path_to_checkpoint=path_to_checkpoint, use_gpu=self.use_gpu)

Просмотреть файл

@ -275,6 +275,7 @@ class InnerEyeLightning(LightningModule):
def validation_epoch_end(self, outputs: List[Any]) -> None:
"""
Resets the random number generator state to what it was before the current validation epoch started.
:param outputs: The list of outputs from the individual validation minibatches.
"""
# reset the random state for training, so that we get continue from where we were before the validation step.
@ -315,12 +316,13 @@ class InnerEyeLightning(LightningModule):
floating point, it is converted to a Tensor on the current device to enable synchronization.
:param sync_dist_override: If not None, use this value for the sync_dist argument to self.log. If None,
set it automatically depending on the use of DDP.
set it automatically depending on the use of DDP.
:param name: The name of the metric to log
:param value: The value of the metric. This can be a tensor, floating point value, or a Metric class.
:param is_training: If true, give the metric a "train/" prefix, otherwise a "val/" prefix.
:param reduce_fx: The reduce function to use when synchronizing the tensors across GPUs. This must be
a value recognized by sync_ddp: "sum", "mean"
a value recognized by sync_ddp: "sum", "mean"
"""
metric_name = name if isinstance(name, str) else name.value
prefix = TRAIN_PREFIX if is_training else VALIDATION_PREFIX
@ -334,10 +336,11 @@ class InnerEyeLightning(LightningModule):
"""
Stores a set of metrics (key/value pairs) to a file logger. That file logger is either one that only holds
training or only holds validation metrics.
:param metrics: A dictionary with all the metrics to write, as key/value pairs.
:param epoch: The epoch to which the metrics belong.
:param is_training: If true, write the metrics to the logger for training metrics, if False, write to the logger
for validation metrics.
for validation metrics.
"""
file_logger = self.train_epoch_metrics_logger if is_training else self.val_epoch_metrics_logger
store_epoch_metrics(metrics, epoch, file_logger=file_logger)
@ -346,6 +349,7 @@ class InnerEyeLightning(LightningModule):
"""
This hook is called when loading a model from a checkpoint. It just prints out diagnostics about which epoch
created the present checkpoint.
:param checkpoint: The checkpoint dictionary loaded from disk.
"""
keys = ['epoch', 'global_step']
@ -371,10 +375,11 @@ class InnerEyeLightning(LightningModule):
"""
This is the shared method that handles the training (when `is_training==True`) and validation steps
(when `is_training==False`)
:param sample: The minibatch of data that should be processed.
:param batch_index: The index of the current minibatch.
:param is_training: If true, this has been called from `training_step`, otherwise it has been called from
`validation_step`.
`validation_step`.
"""
raise NotImplementedError("This method must be overwritten in a derived class.")
@ -382,9 +387,10 @@ class InnerEyeLightning(LightningModule):
"""
Writes the given loss value to Lightning, labelled either "val/loss" or "train/loss".
If this comes from a training step, then also log the learning rate.
:param loss: The loss value that should be logged.
:param is_training: If True, the logged metric will be called "train/Loss". If False, the metric will
be called "val/Loss"
be called "val/Loss"
"""
assert isinstance(self.trainer, Trainer)
self.log_on_epoch(MetricType.LOSS, loss, is_training)

Просмотреть файл

@ -31,16 +31,18 @@ class InnerEyeInference(abc.ABC):
form of inference is slightly different from what PyTorch Lightning does in its `Trainer.test` method. In
particular, this inference can be executed on any of the training, validation, or test set.
The inference code calls the methods in this order:
The inference code calls the methods in this order::
model.on_inference_start()
for dataset_split in [Train, Val, Test]
model.on_inference_epoch_start(dataset_split, is_ensemble_model=False)
for batch_idx, item in enumerate(dataloader[dataset_split])):
model_outputs = model.forward(item)
model.inference_step(item, batch_idx, model_outputs)
model.on_inference_epoch_end()
model.on_inference_end()
model.on_inference_start()
for dataset_split in [Train, Val, Test]
model.on_inference_epoch_start(dataset_split, is_ensemble_model=False)
for batch_idx, item in enumerate(dataloader[dataset_split])):
model_outputs = model.forward(item)
model.inference_step(item, batch_idx, model_outputs)
model.on_inference_epoch_end()
model.on_inference_end()
"""
def on_inference_start(self) -> None:
@ -55,9 +57,10 @@ class InnerEyeInference(abc.ABC):
Runs initialization for inference, when starting inference on a new dataset split (train/val/test).
Depending on the settings, this can be called anywhere between 0 (no inference at all) to 3 times (inference
on all of train/val/test split).
:param dataset_split: Indicates whether the item comes from the training, validation or test set.
:param is_ensemble_model: If False, the model_outputs come from an individual model. If True, the model
outputs come from multiple models.
outputs come from multiple models.
"""
pass
@ -65,6 +68,7 @@ class InnerEyeInference(abc.ABC):
"""
This hook is called when the model has finished making a prediction. It can write the results to a file,
or compute metrics and store them.
:param batch: The batch of data for which the model made a prediction.
:param model_output: The model outputs. This would usually be a torch.Tensor, but can be any datatype.
"""
@ -91,6 +95,7 @@ class InnerEyeInference(abc.ABC):
"""
Aggregates the outputs of multiple models when using an ensemble model. In the default implementation,
this averages the tensors coming from all the models.
:param model_outputs: An iterator over the model outputs for all ensemble members.
:return: The aggregate model outputs.
"""
@ -183,6 +188,7 @@ class LightningContainer(GenericConfig,
Because the method deals with data loaders, not loaded data, we cannot check automatically that cross validation
is handled correctly within the base class, i.e. if the cross validation split is not handled in the method then
nothing will fail, but each child run will be identical since they will each be given the full dataset.
:return: A LightningDataModule
"""
return None # type: ignore
@ -229,6 +235,7 @@ class LightningContainer(GenericConfig,
This can be avoided by always using unique parameter names.
Also note that saving a reference to `azure_config` and updating its attributes at any other
point may lead to unexpected behaviour.
:param azure_config: The initialised AzureConfig whose parameters to override in-place.
"""
pass
@ -297,6 +304,7 @@ class LightningContainer(GenericConfig,
Because this adds a val/Loss metric it is important that when subclassing LightningContainer
your implementation of LightningModule logs val/Loss. There is an example of this in
HelloRegression's validation_step method.
:param run_config: The AzureML run configuration object that training for an individual model.
:return: A hyperdrive configuration object.
"""
@ -315,6 +323,7 @@ class LightningContainer(GenericConfig,
"""
Returns the HyperDrive config for either parameter search or cross validation
(if number_of_cross_validation_splits > 1).
:param run_config: AzureML estimator
:return: HyperDriveConfigs
"""

Просмотреть файл

@ -17,6 +17,7 @@ def load_from_lightning_checkpoint(config: ModelConfigBase, checkpoint_path: Pat
"""
Reads a PyTorch model from a checkpoint. First, a PyTorch Lightning model is created matching the InnerEye
model configuration, its parameter tensors are then populated from the given checkpoint.
:param config: An InnerEye model configuration object
:param checkpoint_path: The location of the checkpoint file.
:return: A PyTorch Lightning model object.
@ -37,6 +38,7 @@ def adjust_model_for_inference(config: ModelConfigBase, lightning_model: InnerEy
Makes all necessary adjustments to use a given model for inference, possibly on multiple GPUs via
model parallelization. The method also computes parameters like output patch size for segmentation model,
and stores them in the model configuration.
:param config: The model configuration object. It may be modified in place.
:param lightning_model: The trained model that should be adjusted.
"""
@ -65,6 +67,7 @@ def load_from_checkpoint_and_adjust_for_inference(config: ModelConfigBase, check
"""
Reads a PyTorch model from a checkpoint, and makes all necessary adjustments to use the model for inference,
possibly on multiple GPUs.
:param config: An InnerEye model configuration object
:param checkpoint_path: The location of the checkpoint file.
:return: A PyTorch Lightning model object.

Просмотреть файл

@ -80,9 +80,10 @@ class StoringLogger(LightningLoggerBase):
Reads the set of metrics for a given epoch, filters them to retain only those that have the given prefix,
and returns the filtered ones. This is used to break a set
of results down into those for training data (prefix "Train/") or validation data (prefix "Val/").
:param epoch: The epoch for which results should be read.
:param prefix_filter: If empty string, return all metrics. If not empty, return only those metrics that
have a name starting with `prefix`, and strip off the prefix.
have a name starting with `prefix`, and strip off the prefix.
:return: A metrics dictionary.
"""
epoch_results = self.results_per_epoch.get(epoch, None)
@ -103,8 +104,9 @@ class StoringLogger(LightningLoggerBase):
Converts the results stored in the present object into a two-level dictionary, mapping from epoch number to
metric name to metric value. Only metrics where the name starts with the given prefix are retained, and the
prefix is stripped off in the result.
:param prefix_filter: If empty string, return all metrics. If not empty, return only those metrics that
have a name starting with `prefix`, and strip off the prefix.
have a name starting with `prefix`, and strip off the prefix.
:return: A dictionary mapping from epoch number to metric name to metric value.
"""
return {epoch: self.extract_by_prefix(epoch, prefix_filter) for epoch in self.epochs}
@ -113,8 +115,10 @@ class StoringLogger(LightningLoggerBase):
"""
Gets a scalar metric out of either the list of training or the list of validation results. This returns
the value that a specific metric attains in all of the epochs.
:param is_training: If True, read metrics that have a "train/" prefix, otherwise those that have a "val/"
prefix.
prefix.
:param metric_type: The metric to extract.
:return: A list of floating point numbers, with one entry per entry in the the training or validation results.
"""
@ -132,6 +136,7 @@ class StoringLogger(LightningLoggerBase):
"""
Gets a scalar metric from the list of training results. This returns
the value that a specific metric attains in all of the epochs.
:param metric_type: The metric to extract.
:return: A list of floating point numbers, with one entry per entry in the the training results.
"""
@ -141,6 +146,7 @@ class StoringLogger(LightningLoggerBase):
"""
Gets a scalar metric from the list of validation results. This returns
the value that a specific metric attains in all of the epochs.
:param metric_type: The metric to extract.
:return: A list of floating point numbers, with one entry per entry in the the validation results.
"""

Просмотреть файл

@ -20,6 +20,7 @@ def nanmean(values: torch.Tensor) -> torch.Tensor:
"""
Computes the average of all values in the tensor, skipping those entries that are NaN (not a number).
If all values are NaN, the result is also NaN.
:param values: The values to average.
:return: A scalar tensor containing the average.
"""
@ -164,7 +165,7 @@ class ScalarMetricsBase(Metric):
difference between true positive rate and false positive rate is smallest. Then, computes
the false positive rate, false negative rate and accuracy at this threshold (i.e. when the
predicted probability is higher than the threshold the predicted label is 1 otherwise 0).
:returns: Tuple(optimal_threshold, false positive rate, false negative rate, accuracy)
:return: Tuple(optimal_threshold, false positive rate, false negative rate, accuracy)
"""
preds, targets = self._get_preds_and_targets()
if torch.unique(targets).numel() == 1:
@ -287,12 +288,14 @@ class MetricForMultipleStructures(torch.nn.Module):
use_average_across_structures: bool = True) -> None:
"""
Creates a new MetricForMultipleStructures object.
:param ground_truth_ids: The list of anatomical structures that should be stored.
:param metric_name: The name of the metric that should be stored. This is used in the names of the individual
metrics.
metrics.
:param is_training: If true, use "train/" as the prefix for all metric names, otherwise "val/"
:param use_average_across_structures: If True, keep track of the average metric value across structures,
while skipping NaNs. If false, only store the per-structure metric values.
while skipping NaNs. If false, only store the per-structure metric values.
"""
super().__init__()
prefix = (TRAIN_PREFIX if is_training else VALIDATION_PREFIX) + metric_name + "/"
@ -307,6 +310,7 @@ class MetricForMultipleStructures(torch.nn.Module):
"""
Stores a vector of per-structure Dice scores in the present object. It updates the per-structure values,
and the aggregate value across all structures.
:param values_per_structure: A row tensor that has as many entries as there are ground truth IDs.
"""
if values_per_structure.dim() != 1 or values_per_structure.numel() != self.count:

Просмотреть файл

@ -51,6 +51,7 @@ class SegmentationLightning(InnerEyeLightning):
"""
Runs a set of 3D crops through the segmentation model, and returns the result. This method is used
at inference time.
:param patches: A tensor of size [batches, channels, Z, Y, X]
"""
return self.logits_to_posterior(self.model(patches))
@ -67,8 +68,10 @@ class SegmentationLightning(InnerEyeLightning):
is_training: bool) -> torch.Tensor:
"""
Runs training for a single minibatch of training or validation data, and computes all metrics.
:param is_training: If true, the method is called from `training_step`, otherwise it is called from
`validation_step`.
`validation_step`.
:param sample: The batched sample on which the model should be trained.
:param batch_index: The index of the present batch (supplied only for diagnostics).
"""
@ -103,10 +106,11 @@ class SegmentationLightning(InnerEyeLightning):
is_training: bool) -> None:
"""
Computes and stores all metrics coming out of a single training step.
:param cropped_sample: The batched image crops used for training or validation.
:param segmentation: The segmentation that was produced by the model.
:param is_training: If true, the method is called from `training_step`, otherwise it is called from
`validation_step`.
`validation_step`.
"""
# dice_per_crop_and_class has one row per crop, with background class removed
# Dice NaN means that both ground truth and prediction are empty.
@ -167,6 +171,7 @@ class SegmentationLightning(InnerEyeLightning):
def get_subject_output_file_per_rank(rank: int) -> str:
"""
Gets the name of a file that will store the per-rank per-subject model outputs.
:param rank: The rank of the current model in distributed training.
:return: A string like "rank7_metrics.csv"
"""
@ -228,11 +233,13 @@ class ScalarLightning(InnerEyeLightning):
is_training: bool) -> torch.Tensor:
"""
Runs training for a single minibatch of training or validation data, and computes all metrics.
:param is_training: If true, the method is called from `training_step`, otherwise it is called from
`validation_step`.
`validation_step`.
:param sample: The batched sample on which the model should be trained.
:param batch_index: The index of the present batch (supplied only for diagnostics).
Runs a minibatch of training or validation data through the model.
Runs a minibatch of training or validation data through the model.
"""
model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model, sample)
labels = model_inputs_and_labels.labels
@ -283,6 +290,7 @@ class ScalarLightning(InnerEyeLightning):
"""
For sequence models, transfer the nested lists of items to the given GPU device.
For all other models, this relies on the superclass to move the batch of data to the GPU.
:param batch: A batch of data coming from the dataloader.
:param device: The target CUDA device.
:return: A modified batch of data, where all tensor now live on the given CUDA device.
@ -294,6 +302,7 @@ def transfer_batch_to_device(batch: Any, device: torch.device) -> Any:
"""
For sequence models, transfer the nested lists of items to the given GPU device.
For all other models, this relies on Lightning's default code to move the batch of data to the GPU.
:param batch: A batch of data coming from the dataloader.
:param device: The target CUDA device.
:return: A modified batch of data, where all tensor now live on the given CUDA device.
@ -314,8 +323,10 @@ def create_lightning_model(config: ModelConfigBase, set_optimizer_and_scheduler:
"""
Creates a PyTorch Lightning model that matches the provided InnerEye model configuration object.
The `optimizer` and `l_rate_scheduler` object of the Lightning model will also be populated.
:param set_optimizer_and_scheduler: If True (default), initialize the optimizer and LR scheduler of the model.
If False, skip that step (this is only meant to be used for unit tests.)
If False, skip that step (this is only meant to be used for unit tests.)
:param config: An InnerEye model configuration object
:return: A PyTorch Lightning model object.
"""

Просмотреть файл

@ -65,6 +65,7 @@ class InferenceMetricsForSegmentation(InferenceMetrics):
def log_metrics(self, run_context: Run = None) -> None:
"""
Log metrics for each epoch to the provided runs logs, or the current run context if None provided
:param run_context: Run for which to log the metrics to, use the current run context if None provided
:return:
"""
@ -79,6 +80,7 @@ def surface_distance(seg: sitk.Image, reference_segmentation: sitk.Image) -> flo
"""
Symmetric surface distances taking into account the image spacing
https://github.com/InsightSoftwareConsortium/SimpleITK-Notebooks/blob/master/Python/34_Segmentation_Evaluation.ipynb
:param seg: mask 1
:param reference_segmentation: mask 2
:return: mean distance
@ -118,6 +120,7 @@ def surface_distance(seg: sitk.Image, reference_segmentation: sitk.Image) -> flo
def _add_zero_distances(num_segmented_surface_pixels: int, seg2ref_distance_map_arr: np.ndarray) -> List[float]:
"""
# Get all non-zero distances and then add zero distances if required.
:param num_segmented_surface_pixels:
:param seg2ref_distance_map_arr:
:return: list of distances, augmented with zeros.
@ -137,6 +140,7 @@ def calculate_metrics_per_class(segmentation: np.ndarray,
Returns a MetricsDict with metrics for each of the foreground
structures. Metrics are NaN if both ground truth and prediction are all zero for a class.
If first element of a ground truth image channel is NaN, the image is flagged as NaN and not use.
:param ground_truth_ids: The names of all foreground classes.
:param segmentation: predictions multi-value array with dimensions: [Z x Y x X]
:param ground_truth: ground truth binary array with dimensions: [C x Z x Y x X].
@ -217,12 +221,13 @@ def compute_dice_across_patches(segmentation: torch.Tensor,
allow_multiple_classes_for_each_pixel: bool = False) -> torch.Tensor:
"""
Computes the Dice scores for all classes across all patches in the arguments.
:param segmentation: Tensor containing class ids predicted by a model.
:param ground_truth: One-hot encoded torch tensor containing ground-truth label ids.
:param allow_multiple_classes_for_each_pixel: If set to False, ground-truth tensor has
to contain only one foreground label for each pixel.
:return A torch tensor of size (Patches, Classes) with the Dice scores. Dice scores are computed for
all classes including the background class at index 0.
to contain only one foreground label for each pixel.
:return: A torch tensor of size (Patches, Classes) with the Dice scores. Dice scores are computed for
all classes including the background class at index 0.
"""
check_size_matches(segmentation, ground_truth, 4, 5, [0, -3, -2, -1],
arg1_name="segmentation", arg2_name="ground_truth")
@ -255,6 +260,7 @@ def store_epoch_metrics(metrics: DictStrFloat,
"""
Writes all metrics (apart from ones that measure run time) into a CSV file,
with an additional columns for epoch number.
:param file_logger: An instance of DataframeLogger, for logging results to csv.
:param epoch: The epoch corresponding to the results.
:param metrics: The metrics of the specified epoch, averaged along its batches.
@ -291,12 +297,13 @@ def compute_scalar_metrics(metrics_dict: ScalarMetricsDict,
of class 1. The label vector is expected to contain class indices 0 and 1 only.
Metrics for each model output channel will be isolated, and a non-default hue for each model output channel is
expected, and must exist in the provided metrics_dict. The Default hue is used for single model outputs.
:param metrics_dict: An object that holds all metrics. It will be updated in-place.
:param subject_ids: Subject ids for the model output and labels.
:param model_output: A tensor containing model outputs.
:param labels: A tensor containing class labels.
:param loss_type: The type of loss that the model uses. This is required to optionally convert 2-dim model output
to probabilities.
to probabilities.
"""
_model_output_channels = model_output.shape[1]
model_output_hues = metrics_dict.get_hue_names(include_default=len(metrics_dict.hues_without_default) == 0)
@ -346,6 +353,7 @@ def add_average_foreground_dice(metrics: MetricsDict) -> None:
"""
If the given metrics dictionary contains an entry for Dice score, and only one value for the Dice score per class,
then add an average Dice score for all foreground classes to the metrics dictionary (modified in place).
:param metrics: The object that holds metrics. The average Dice score will be written back into this object.
"""
all_dice = []

Просмотреть файл

@ -35,9 +35,10 @@ def average_metric_values(values: List[float], skip_nan_when_averaging: bool) ->
"""
Returns the average (arithmetic mean) of the values provided. If skip_nan_when_averaging is True, the mean
will be computed without any possible NaN values in the list.
:param values: The individual values that should be averaged.
:param skip_nan_when_averaging: If True, compute mean with any NaN values. If False, any NaN value present
in the argument will make the function return NaN.
in the argument will make the function return NaN.
:return: The average of the provided values. If the argument is an empty list, NaN will be returned.
"""
if skip_nan_when_averaging:
@ -61,6 +62,7 @@ def get_column_name_for_logging(metric_name: Union[str, MetricType],
"""
Computes the column name that should be used when logging a metric to disk.
Raises a value error when no column name has yet been defined.
:param metric_name: The name of the metric.
:param hue_name: If provided will be used as a prefix hue_name/column_name
"""
@ -104,11 +106,13 @@ class Hue:
labels: np.ndarray) -> None:
"""
Adds predictions and labels for later computing the area under the ROC curve.
:param subject_ids: Subject ids associated with the predictions and labels.
:param predictions: A numpy array with model predictions, of size [N x C] for N samples in C classes, or size
[N x 1] or size [N] for binary.
[N x 1] or size [N] for binary.
:param labels: A numpy array with labels, of size [N x C] for N samples in C classes, or size
[N x 1] or size [N] for binary.
[N x 1] or size [N] for binary.
"""
if predictions.ndim == 1:
predictions = np.expand_dims(predictions, axis=1)
@ -155,6 +159,7 @@ class Hue:
def _concat_if_needed(arrays: List[np.ndarray]) -> np.ndarray:
"""
Joins a list of arrays into a single array, taking empty lists into account correctly.
:param arrays: Array list to be concatenated.
"""
if arrays:
@ -192,7 +197,8 @@ class MetricsDict:
def __init__(self, hues: Optional[List[str]] = None, is_classification_metrics: bool = True) -> None:
"""
:param hues: Supported hues for this metrics dict, otherwise all records will belong to the
default hue.
default hue.
:param is_classification_metrics: If this is a classification metrics dict
"""
@ -210,14 +216,16 @@ class MetricsDict:
def subject_ids(self, hue: str = DEFAULT_HUE_KEY) -> List[str]:
"""
Return the subject ids that have metrics associated with them in this dictionary.
:param hue: If provided then subject ids belonging to this hue only will be returned.
Otherwise subject ids for the default hue will be returned.
Otherwise subject ids for the default hue will be returned.
"""
return self._get_hue(hue=hue).subject_ids
def get_hue_names(self, include_default: bool = True) -> List[str]:
"""
Returns all of the hues supported by this metrics dict
:param include_default: Include the default hue if True, otherwise exclude the default hue.
"""
_hue_names = list(self.hues.keys())
@ -228,6 +236,7 @@ class MetricsDict:
def delete_hue(self, hue: str) -> None:
"""
Removes all data stored for the given hue from the present object.
:param hue: The hue to remove.
"""
del self.hues[hue]
@ -236,6 +245,7 @@ class MetricsDict:
"""
Gets the value stored for the given metric. The method assumes that there is a single value stored for the
metric, and raises a ValueError if that is not the case.
:param metric_name: The name of the metric to retrieve.
:param hue: The hue to retrieve the metric from.
:return:
@ -249,6 +259,7 @@ class MetricsDict:
def has_prediction_entries(self, hue: str = DEFAULT_HUE_KEY) -> bool:
"""
Returns True if the present object stores any entries for computing the Area Under Roc Curve metric.
:param hue: will be used to check a particular hue otherwise default hue will be used.
:return: True if entries exist. False otherwise.
"""
@ -257,8 +268,9 @@ class MetricsDict:
def values(self, hue: str = DEFAULT_HUE_KEY) -> Dict[str, Any]:
"""
Returns values held currently in the dict
:param hue: will be used to restrict values for the provided hue otherwise values in the default
hue will be returned.
hue will be returned.
:return: Dictionary of values for this object.
"""
return self._get_hue(hue).values
@ -267,6 +279,7 @@ class MetricsDict:
"""
Adds a diagnostic value to the present object. Multiple diagnostics can be stored per unique value of name,
the values get concatenated.
:param name: The name of the diagnostic value to store.
:param value: The value to store.
"""
@ -292,10 +305,12 @@ class MetricsDict:
hue: str = DEFAULT_HUE_KEY) -> None:
"""
Adds values for a single metric to the present object, when the metric value is a scalar.
:param metric_name: The name of the metric to add. This can be a string or a value in the MetricType enum.
:param metric_value: The values of the metric, as a float or integer.
:param skip_nan_when_averaging: If True, averaging this metric will skip any NaN (not a number) values.
If False, NaN will propagate through the mean computation.
If False, NaN will propagate through the mean computation.
:param hue: The hue for which this record belongs to, default hue will be used if None provided.
"""
_metric_name = MetricsDict._metric_name(metric_name)
@ -315,6 +330,7 @@ class MetricsDict:
hue: str = DEFAULT_HUE_KEY) -> None:
"""
Deletes all values that are stored for a given metric from the present object.
:param metric_name: The name of the metric to add. This can be a string or a value in the MetricType enum.
:param hue: The hue for which this record belongs to, default hue will be used if None provided.
"""
@ -327,11 +343,14 @@ class MetricsDict:
hue: str = DEFAULT_HUE_KEY) -> None:
"""
Adds predictions and labels for later computing the area under the ROC curve.
:param subject_ids: Subject ids associated with the predictions and labels.
:param predictions: A numpy array with model predictions, of size [N x C] for N samples in C classes, or size
[N x 1] or size [N] for binary.
[N x 1] or size [N] for binary.
:param labels: A numpy array with labels, of size [N x C] for N samples in C classes, or size
[N x 1] or size [N] for binary.
[N x 1] or size [N] for binary.
:param hue: The hue this prediction belongs to, default hue will be used if None provided.
"""
self._get_hue(hue).add_predictions(subject_ids=subject_ids,
@ -341,6 +360,7 @@ class MetricsDict:
def num_entries(self, hue: str = DEFAULT_HUE_KEY) -> Dict[str, int]:
"""
Gets the number of values that are stored for each individual metric.
:param hue: The hue to count entries for, otherwise all entries will be counted.
:return: A dictionary mapping from metric name to number of values stored.
"""
@ -355,9 +375,10 @@ class MetricsDict:
object.
Computing the average will respect the skip_nan_when_averaging value that has been provided when adding
the metric.
:param add_metrics_from_entries: average existing metrics in the dict.
:param across_hues: If True then same metric types will be averaged regardless of hues, otherwise
separate averages for each metric type for each hue will be computed, Default is True.
separate averages for each metric type for each hue will be computed, Default is True.
:return: A MetricsDict object with a single-item list for each of the metrics.
"""
@ -434,8 +455,9 @@ class MetricsDict:
difference between true positive rate and false positive rate is smallest. Then, computes
the false positive rate, false negative rate and accuracy at this threshold (i.e. when the
predicted probability is higher than the threshold the predicted label is 1 otherwise 0).
:param hue: The hue to restrict the values used for computation, otherwise all values will be used.
:returns: Tuple(optimal_threshold, false positive rate, false negative rate, accuracy)
:return: Tuple(optimal_threshold, false positive rate, false negative rate, accuracy)
"""
fpr, tpr, thresholds = roc_curve(self.get_labels(hue=hue), self.get_predictions(hue=hue))
optimal_idx = MetricsDict.get_optimal_idx(fpr=fpr, tpr=tpr)
@ -450,6 +472,7 @@ class MetricsDict:
def get_roc_auc(self, hue: str = DEFAULT_HUE_KEY) -> float:
"""
Computes the Area Under the ROC curve, from the entries that were supplied in the add_roc_entries method.
:param hue: The hue to restrict the values used for computation, otherwise all values will be used.
:return: The AUC score, or np.nan if no entries are available in the present object.
"""
@ -469,6 +492,7 @@ class MetricsDict:
"""
Computes the Area Under the Precision Recall Curve, from the entries that were supplied in the
add_roc_entries method.
:param hue: The hue to restrict the values used for computation, otherwise all values will be used.
:return: The PR AUC score, or np.nan if no entries are available in the present object.
"""
@ -488,6 +512,7 @@ class MetricsDict:
"""
Computes the binary cross entropy from the entries that were supplied in the
add_roc_entries method.
:param hue: The hue to restrict the values used for computation, otherwise all values will be used.
:return: The cross entropy score.
"""
@ -498,6 +523,7 @@ class MetricsDict:
def get_mean_absolute_error(self, hue: str = DEFAULT_HUE_KEY) -> float:
"""
Get the mean absolute error.
:param hue: The hue to restrict the values used for computation, otherwise all values will be used.
:return: Mean absolute error.
"""
@ -506,6 +532,7 @@ class MetricsDict:
def get_mean_squared_error(self, hue: str = DEFAULT_HUE_KEY) -> float:
"""
Get the mean squared error.
:param hue: The hue to restrict the values used for computation, otherwise all values will be used.
:return: Mean squared error
"""
@ -514,6 +541,7 @@ class MetricsDict:
def get_r2_score(self, hue: str = DEFAULT_HUE_KEY) -> float:
"""
Get the R2 score.
:param hue: The hue to restrict the values used for computation, otherwise all values will be used.
:return: R2 score
"""
@ -524,6 +552,7 @@ class MetricsDict:
Returns an iterator that contains all (hue name, metric name, metric values) tuples that are stored in the
present object. This method assumes that for each hue/metric combination there is exactly 1 value, and it
throws an exception if that is more than 1 value.
:param hue: The hue to restrict the values, otherwise all values will be used if set to None.
:return: An iterator with (hue name, metric name, metric values) pairs.
"""
@ -536,6 +565,7 @@ class MetricsDict:
"""
Returns an iterator that contains all (hue name, metric name, metric values) tuples that are stored in the
present object.
:param hue: The hue to restrict the values, otherwise all values will be used if set to None.
:param ensure_singleton_values_only: Ensure that each of the values return is a singleton.
:return: An iterator with (hue name, metric name, metric values) pairs.
@ -566,6 +596,7 @@ class MetricsDict:
def get_predictions(self, hue: str = DEFAULT_HUE_KEY) -> np.ndarray:
"""
Return a concatenated copy of the roc predictions stored internally.
:param hue: The hue to restrict the values, otherwise all values will be used.
:return: concatenated roc predictions as np array
"""
@ -574,6 +605,7 @@ class MetricsDict:
def get_labels(self, hue: str = DEFAULT_HUE_KEY) -> np.ndarray:
"""
Return a concatenated copy of the roc labels stored internally.
:param hue: The hue to restrict the values, otherwise all values will be used.
:return: roc labels as np array
"""
@ -583,6 +615,7 @@ class MetricsDict:
-> List[PredictionEntry[float]]:
"""
Gets the per-subject labels and predictions that are stored in the present object.
:param hue: The hue to restrict the values, otherwise the default hue will be used.
:return: List of per-subject labels and predictions
"""
@ -591,6 +624,7 @@ class MetricsDict:
def to_string(self, tabulate: bool = True) -> str:
"""
Creates a multi-line human readable string from the given metrics.
:param tabulate: If True then create a pretty printable table string.
:return: Formatted metrics string
"""
@ -623,6 +657,7 @@ class MetricsDict:
"""
Get the hue record for the provided key.
Raises a KeyError if the provided hue key does not exist.
:param hue: The hue to retrieve record for
"""
if hue not in self.hues:
@ -654,6 +689,7 @@ class ScalarMetricsDict(MetricsDict):
cross_validation_split_index: int = DEFAULT_CROSS_VALIDATION_SPLIT_INDEX) -> None:
"""
Store metrics using the provided df_logger at subject level for classification models.
:param df_logger: A data frame logger to use to write the metrics to disk.
:param mode: Model execution mode these metrics belong to.
:param cross_validation_split_index: cross validation split index for the epoch if performing cross val
@ -677,8 +713,9 @@ class ScalarMetricsDict(MetricsDict):
"""
Helper function to create BinaryClassificationMetricsDict grouped by ModelExecutionMode and epoch
from a given dataframe. The following columns must exist in the provided data frame:
>>> LoggingColumns.DataSplit
>>> LoggingColumns.Epoch
* LoggingColumns.DataSplit
* LoggingColumns.Epoch
:param df: DataFrame to use for creating the metrics dict.
:param is_classification_metrics: If the current metrics are for classification or not.
@ -720,6 +757,7 @@ class ScalarMetricsDict(MetricsDict):
Given metrics dicts for execution modes and epochs, compute the aggregate metrics that are computed
from the per-subject predictions. The metrics are written to the dataframe logger with the string labels
(column names) taken from the `MetricType` enum.
:param metrics: Mapping between epoch and subject level metrics
:param data_frame_logger: DataFrame logger to write to and flush
:param log_info: If True then log results as an INFO string to the default logger also.
@ -781,6 +819,7 @@ class SequenceMetricsDict(ScalarMetricsDict):
"""
Extracts a sequence target index from a metrics hue name. For example, from metrics hue "Seq_pos 07",
it would return 7.
:param hue_name: hue name containing sequence target index
"""
if hue_name.startswith(SEQUENCE_POSITION_HUE_NAME_PREFIX):
@ -807,6 +846,7 @@ class DataframeLogger:
def flush(self, log_info: bool = False) -> None:
"""
Save the internal records to a csv file.
:param log_info: If true, write the final dataframe also to logging.info.
"""
import pandas as pd

Просмотреть файл

@ -55,6 +55,7 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
Returns a configuration for AzureML Hyperdrive that should be used when running hyperparameter
tuning.
This is an abstract method that each specific model should override.
:param run_config: The AzureML estimator object that runs model training.
:return: A hyperdrive configuration object.
"""
@ -66,6 +67,7 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
"""
Computes the training, validation and test splits for the model, from a dataframe that contains
the full dataset.
:param dataset_df: A dataframe that contains the full dataset that the model is using.
:return: An instance of DatasetSplits with dataframes for training, validation and testing.
"""
@ -83,6 +85,7 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
are False, the derived method *may* still create the corresponding datasets, but should not assume that
the relevant splits (train/test/val) are non-empty. If either or both is True, they *must* create the
corresponding datasets, and should be able to make the assumption.
:param for_training: whether to create the datasets required for training.
:param for_inference: whether to create the datasets required for inference.
"""
@ -103,6 +106,8 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
"""
Returns a torch Dataset for running the model in inference mode, on the given split of the full dataset.
The torch dataset must return data in the format required for running the model in inference mode.
:param mode: The mode of the model, either test, train or val.
:return: A torch Dataset object.
"""
if self._datasets_for_inference is None:
@ -114,7 +119,7 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
"""
Creates the torch DataLoaders that supply the training and the validation set during training only.
:return: A dictionary, with keys ModelExecutionMode.TRAIN and ModelExecutionMode.VAL, and their respective
data loaders.
data loaders.
"""
logging.info("Starting to read and parse the datasets.")
if self._datasets_for_training is None:
@ -161,6 +166,7 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
def get_cross_validation_hyperdrive_config(self, run_config: ScriptRunConfig) -> HyperDriveConfig:
"""
Returns a configuration for AzureML Hyperdrive that varies the cross validation split index.
:param run_config: The AzureML run configuration object that training for an individual model.
:return: A hyperdrive configuration object.
"""
@ -176,9 +182,10 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
"""
When running cross validation, this method returns the dataset split that should be used for the
currently executed cross validation split.
:param dataset_split: The full dataset, split into training, validation and test section.
:return: The dataset split with training and validation sections shuffled according to the current
cross validation index.
cross validation index.
"""
splits = dataset_split.get_k_fold_cross_validation_splits(self.number_of_cross_validation_splits)
return splits[self.cross_validation_split_index]
@ -187,6 +194,7 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
"""
Returns the HyperDrive config for either parameter search or cross validation
(if number_of_cross_validation_splits > 1).
:param run_config: AzureML estimator
:return: HyperDriveConfigs
"""
@ -239,6 +247,7 @@ class ModelConfigBase(DeepLearningConfig, abc.ABC, metaclass=ModelConfigBaseMeta
A hook to adjust the model configuration that is stored in the present object to match
the torch model given in the argument. This hook is called after adjusting the model for
mixed precision and parallel training.
:param model: The torch model.
"""
pass
@ -267,7 +276,8 @@ class ModelTransformsPerExecutionMode:
"""
:param train: the transformation(s) to apply to the training set.
Should be a function that takes a sample as input and outputs sample.
Should be a function that takes a sample as input and outputs sample.
:param val: the transformation(s) to apply to the validation set
:param test: the transformation(s) to apply to the test set
"""

Просмотреть файл

@ -51,6 +51,7 @@ def model_test(config: ModelConfigBase,
Runs model inference on segmentation or classification models, using a given dataset (that could be training,
test or validation set). The inference results and metrics will be stored and logged in a way that may
differ for model categories (classification, segmentation).
:param config: The configuration of the model
:param data_split: Indicates which of the 3 sets (training, test, or validation) is being processed.
:param checkpoint_paths: Checkpoint paths to initialize model.
@ -82,6 +83,7 @@ def segmentation_model_test(config: SegmentationModelBase,
"""
The main testing loop for segmentation models.
It loads the model and datasets, then proceeds to test the model for all requested checkpoints.
:param config: The arguments object which has a valid random seed attribute.
:param execution_mode: Indicates which of the 3 sets (training, test, or validation) is being processed.
:param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization.
@ -121,6 +123,7 @@ def segmentation_model_test_epoch(config: SegmentationModelBase,
The main testing loop for a given epoch. It loads the model and datasets, then proceeds to test the model.
Returns a list with an entry for each image in the dataset. The entry is the average Dice score,
where the average is taken across all non-background structures in the image.
:param checkpoint_paths: Checkpoint paths to run inference on.
:param config: The arguments which specify all required information.
:param execution_mode: Is the model evaluated on train, test, or validation set?
@ -128,7 +131,7 @@ def segmentation_model_test_epoch(config: SegmentationModelBase,
:param epoch_and_split: A string that should uniquely identify the epoch and the data split (train/val/test).
:raises TypeError: If the arguments are of the wrong type.
:raises ValueError: When there are issues loading the model.
:return A list with the mean dice score (across all structures apart from background) for each image.
:return: A list with the mean dice score (across all structures apart from background) for each image.
"""
ml_util.set_random_seed(config.get_effective_random_seed(), "Model testing")
results_folder.mkdir(exist_ok=True)
@ -204,11 +207,12 @@ def evaluate_model_predictions(process_id: int,
Evaluates model segmentation predictions, dice scores and surface distances are computed.
Generated contours are plotted and saved in results folder.
The function is intended to be used in parallel for loop to process each image in parallel.
:param process_id: Identifier for the process calling the function
:param config: Segmentation model config object
:param dataset: Dataset object, it is used to load intensity image, labels, and patient metadata.
:param results_folder: Path to results folder
:returns [PatientMetadata, list[list]]: Patient metadata and list of computed metrics for each image.
:return: [PatientMetadata, list[list]]: Patient metadata and list of computed metrics for each image.
"""
sample = dataset.get_samples_at_index(index=process_id)[0]
logging.info(f"Evaluating predictions for patient {sample.patient_id}")
@ -235,10 +239,12 @@ def populate_metrics_writer(
config: SegmentationModelBase) -> Tuple[MetricsPerPatientWriter, List[FloatOrInt]]:
"""
Populate a MetricsPerPatientWriter with the metrics for each patient
:param model_prediction_evaluations: The list of PatientMetadata/MetricsDict tuples obtained
from evaluate_model_predictions
from evaluate_model_predictions
:param config: The SegmentationModelBase config from which we read the ground_truth_ids
:returns: A new MetricsPerPatientWriter and a list of foreground DICE score averages
:return: A new MetricsPerPatientWriter and a list of foreground DICE score averages
"""
average_dice: List[FloatOrInt] = []
metrics_writer = MetricsPerPatientWriter()
@ -263,6 +269,7 @@ def get_patient_results_folder(results_folder: Path, patient_id: int) -> Path:
"""
Gets a folder name that will contain all results for a given patient, like root/017 for patient 17.
The folder name is constructed such that string sorting gives numeric sorting.
:param results_folder: The root folder in which the per-patient results should sit.
:param patient_id: The numeric ID of the patient.
:return: A path like "root/017"
@ -276,8 +283,10 @@ def store_inference_results(inference_result: InferencePipeline.Result,
image_header: ImageHeader) -> List[Path]:
"""
Store the segmentation, posteriors, and binary predictions into Nifti files.
:param inference_result: The inference result for a given patient_id and epoch. Posteriors must be in
(Classes x Z x Y x X) shape, segmentation in (Z, Y, X)
(Classes x Z x Y x X) shape, segmentation in (Z, Y, X)
:param config: The test configurations.
:param results_folder: The folder where the prediction should be stored.
:param image_header: The image header that was used in the input image.
@ -286,6 +295,7 @@ def store_inference_results(inference_result: InferencePipeline.Result,
def create_file_path(_results_folder: Path, _file_name: str) -> Path:
"""
Create filename with Nifti extension
:param _results_folder: The results folder
:param _file_name: The name of the file
:return: A full path to the results folder for the file
@ -339,6 +349,7 @@ def store_run_information(results_folder: Path,
image_channels: List[str]) -> None:
"""
Store dataset id and ground truth ids into files in the results folder.
:param image_channels: The names of the image channels that the model consumes.
:param results_folder: The folder where the files should be stored.
:param dataset_id: The dataset id
@ -359,6 +370,7 @@ def create_inference_pipeline(config: ModelConfigBase,
"""
If multiple checkpoints are found in run_recovery then create EnsemblePipeline otherwise InferencePipeline.
If no checkpoint files exist in the run recovery or current run checkpoint folder, None will be returned.
:param config: Model related configs.
:param epoch: The epoch for which to create pipeline for.
:param run_recovery: RunRecovery data if applicable
@ -408,9 +420,11 @@ def classification_model_test(config: ScalarModelBase,
"""
The main testing loop for classification models. It runs a loop over all epochs for which testing should be done.
It loads the model and datasets, then proceeds to test the model for all requested checkpoints.
:param config: The model configuration.
:param data_split: The name of the folder to store the results inside each epoch folder in the outputs_dir,
used mainly in model evaluation using different dataset splits.
:param checkpoint_paths: Checkpoint paths to initialize model
:param model_proc: whether we are testing an ensemble or single model
:return: InferenceMetricsForClassification object that contains metrics related for all of the checkpoint epochs.

Просмотреть файл

@ -38,6 +38,7 @@ def upload_output_file_as_temp(file_path: Path, outputs_folder: Path) -> None:
"""
Uploads a file to the AzureML run. It will get a name that is composed of a "temp/" prefix, plus the path
of the file relative to the outputs folder that is used for training.
:param file_path: The path of the file to upload.
:param outputs_folder: The root folder that contains all training outputs.
"""
@ -65,6 +66,7 @@ def create_lightning_trainer(container: LightningContainer,
Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
and loggers. That includes a diagnostic logger for use in unit tests, that is also returned as the second
return value.
:param container: The container with model and data.
:param resume_from_checkpoint: If provided, training resumes from this checkpoint point.
:param num_nodes: The number of nodes to use in distributed training.
@ -204,13 +206,14 @@ def model_train(checkpoint_path: Optional[Path],
The main training loop. It creates the Pytorch model based on the configuration options passed in,
creates a Pytorch Lightning trainer, and trains the model.
If a checkpoint was specified, then it loads the checkpoint before resuming training.
:param checkpoint_path: Checkpoint path for model initialization
:param num_nodes: The number of nodes to use in distributed training.
:param container: A container object that holds the training data in PyTorch Lightning format
and the model to train.
and the model to train.
:return: A tuple of [Trainer, StoringLogger]. Trainer is the Lightning Trainer object that was used for fitting
the model. The StoringLogger object is returned when training an InnerEye built-in model, this is None when
fitting other models.
the model. The StoringLogger object is returned when training an InnerEye built-in model, this is None when
fitting other models.
"""
lightning_model = container.model
@ -323,8 +326,9 @@ def model_train(checkpoint_path: Optional[Path],
def aggregate_and_create_subject_metrics_file(outputs_folder: Path) -> None:
"""
This functions takes all the subject metrics file written by each GPU (one file per GPU) and aggregates them into
one single metrics file. Results is saved in config.outputs_folder / mode.value / SUBJECT_METRICS_FILE_NAME.
one single metrics file. Results is saved in ``config.outputs_folder / mode.value / SUBJECT_METRICS_FILE_NAME``.
This is done for the metrics files for training and for validation data separately.
:param config: model config
"""
for mode in [ModelExecutionMode.TRAIN, ModelExecutionMode.VAL]:

Просмотреть файл

@ -23,8 +23,9 @@ class CropSizeConstraints:
"""
:param multiple_of: Stores minimum size and other conditions that a training crop size must satisfy.
:param minimum_size: Training crops must have a size that is a multiple of this value, along each dimension.
For example, if set to (1, 16, 16), the crop size has to be a multiple of 16 along X and Y, and a
For example, if set to (1, 16, 16), the crop size has to be a multiple of 16 along X and Y, and a
multiple of 1 (i.e., any number) along the Z dimension.
:param num_dimensions: Training crops must have a size that is at least this value.
"""
self.multiple_of = multiple_of
@ -56,6 +57,7 @@ class CropSizeConstraints:
"""
Checks if the given crop size is a valid crop size for the present model.
If it is not valid, throw a ValueError.
:param crop_size: The crop size that should be checked.
:param message_prefix: A string prefix for the error message if the crop size is found to be invalid.
:return:
@ -83,6 +85,7 @@ class CropSizeConstraints:
(at test time). The new crop size will be the largest multiple of self.multiple_of that fits into the
image_shape.
The stride size will attempt to maintain the stride-to-crop ratio before adjustment.
:param image_shape: The shape of the image to process.
:param crop_size: The present test crop size.
:param stride_size: The present inference stride size.
@ -121,10 +124,11 @@ class BaseSegmentationModel(DeviceAwareModule, ABC):
):
"""
Creates a new instance of the base model class.
:param name: A human readable name of the model.
:param input_channels: The number of image input channels.
:param crop_size_constraints: The size constraints for the training crop size. If not provided,
a minimum crop size of 1 is assumed.
a minimum crop size of 1 is assumed.
"""
super().__init__()
self.num_dimensions = 3
@ -144,6 +148,7 @@ class BaseSegmentationModel(DeviceAwareModule, ABC):
The argument is expected to be either a 2-tuple or a 3-tuple. A batch dimension (1)
and the number of channels are added as the first dimensions. The result tuple has batch and channel dimension
stripped off.
:param input_shape: A tuple (2D or 3D) representing incoming tensor shape.
"""
# Create a sample tensor for inference
@ -166,6 +171,7 @@ class BaseSegmentationModel(DeviceAwareModule, ABC):
"""
Checks if the given crop size is a valid crop size for the present model.
If it is not valid, throw a ValueError.
:param crop_size: The crop size that should be checked.
:param message_prefix: A string prefix for the error message if the crop size is found to be invalid.
"""
@ -178,8 +184,9 @@ class BaseSegmentationModel(DeviceAwareModule, ABC):
Stores a model summary, containing information about layers, memory consumption and runtime
in the model.summary field.
When called again with the same crop_size, the summary is not created again.
:param crop_size: The crop size for which the summary should be created. If not provided,
the minimum allowed crop size is used.
the minimum allowed crop size is used.
:param log_summaries_to_files: whether to write the summary to a file
"""
if crop_size is None:

Просмотреть файл

@ -64,23 +64,29 @@ class ImageEncoder(DeviceAwareModule[ScalarItem, torch.Tensor]):
"""
Creates an image classifier that has UNet encoders sections for each image channel. The encoder output
is fed through average pooling and an MLP.
:param encode_channels_jointly: If False, create a UNet encoder structure separately for each channel. If True,
encode all channels jointly (convolution will run over all channels).
encode all channels jointly (convolution will run over all channels).
:param num_encoder_blocks: Number of UNet encoder blocks.
:param initial_feature_channels: Number of feature channels in the first UNet encoder.
:param num_image_channels: Number of channels of the input. Input is expected to be of size
B x num_image_channels x Z x Y x X, where B is the batch dimension.
B x num_image_channels x Z x Y x X, where B is the batch dimension.
:param num_non_image_features: Number of non imaging features will be used in the model.
:param kernel_size_per_encoding_block: The size of the kernels per encoding block, assumed to be the same
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
performs convolutions only in X and Y.
:param stride_size_per_encoding_block: The stride size for the encoding block, assumed to be the same
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
reduces spatial dimensions only in X and Y.
:param encoder_dimensionality_reduction_factor: how to reduce the dimensionality of the image features in the
combined model to balance with non imaging features.
combined model to balance with non imaging features.
:param scan_size: should be a tuple representing 3D tensor shape and if specified it's usedd in initializing
gated pooling or z-adaptive. The first element should be representing the z-direction for classification images
gated pooling or z-adaptive. The first element should be representing the z-direction for classification images
"""
super().__init__()
self.num_non_image_features = num_non_image_features
@ -168,6 +174,7 @@ class ImageEncoder(DeviceAwareModule[ScalarItem, torch.Tensor]):
def _get_aggregation_layer(self, aggregation_type: AggregationType, scan_size: Optional[TupleInt3]) -> Any:
"""
Returns the aggregation layer as specified by the config
:param aggregation_type: name of the aggregation
:param scan_size: [Z, Y, X] size of the scans
"""
@ -191,6 +198,7 @@ class ImageEncoder(DeviceAwareModule[ScalarItem, torch.Tensor]):
def get_input_tensors(self, item: ScalarItem) -> List[torch.Tensor]:
"""
Transforms a classification item into a torch.Tensor that the forward pass can consume
:param item: ClassificationItem
:return: Tensor
"""
@ -290,25 +298,31 @@ class ImageEncoderWithMlp(ImageEncoder):
Creates an image classifier that has UNet encoders sections for each image channel. The encoder output
is fed through average pooling and an MLP. Extension of the ImageEncoder class using an MLP as classification
layer.
:param encode_channels_jointly: If False, create a UNet encoder structure separately for each channel. If True,
encode all channels jointly (convolution will run over all channels).
encode all channels jointly (convolution will run over all channels).
:param num_encoder_blocks: Number of UNet encoder blocks.
:param initial_feature_channels: Number of feature channels in the first UNet encoder.
:param num_image_channels: Number of channels of the input. Input is expected to be of size
B x num_image_channels x Z x Y x X, where B is the batch dimension.
B x num_image_channels x Z x Y x X, where B is the batch dimension.
:param mlp_dropout: The amount of dropout that should be applied between the two layers of the classifier MLP.
:param final_activation: Activation function to normalize the logits default is Identity.
:param num_non_image_features: Number of non imaging features will be used in the model.
:param kernel_size_per_encoding_block: The size of the kernels per encoding block, assumed to be the same
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
performs convolutions only in X and Y.
:param stride_size_per_encoding_block: The stride size for the encoding block, assumed to be the same
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
if a single tuple is provided. Otherwise the list of tuples must match num_encoder_blocks. Default
reduces spatial dimensions only in X and Y.
:param encoder_dimensionality_reduction_factor: how to reduce the dimensionality of the image features in the
combined model to balance with non imaging features.
combined model to balance with non imaging features.
:param scan_size: should be a tuple representing 3D tensor shape and if specified it's usedd in initializing
gated pooling or z-adaptive. The first element should be representing the z-direction for classification images
gated pooling or z-adaptive. The first element should be representing the z-direction for classification images
"""
super().__init__(imaging_feature_type=imaging_feature_type,
encode_channels_jointly=encode_channels_jointly,
@ -369,12 +383,13 @@ def create_mlp(input_num_feature_channels: int,
hidden_layer_num_feature_channels: Optional[int] = None) -> MLP:
"""
Create an MLP with 1 hidden layer.
:param input_num_feature_channels: The number of input channels to the first MLP layer.
:param dropout: The drop out factor that should be applied between the first and second MLP layer.
:param final_output_channels: if provided, the final number of output channels.
:param final_layer: if provided, the final (activation) layer to apply
:param hidden_layer_num_feature_channels: if provided, will be used to create hidden layers, If None then
input_num_feature_channels // 2 will be used to create the hidden layer.
input_num_feature_channels // 2 will be used to create the hidden layer.
:return:
"""
layers: List[torch.nn.Module] = []

Просмотреть файл

@ -160,11 +160,12 @@ class MultiSegmentationEncoder(DeviceAwareModule[ScalarItem, torch.Tensor]):
) -> None:
"""
:param encode_channels_jointly: If False, create an encoder structure separately for each channel. If True,
encode all channels jointly (convolution will run over all channels).
encode all channels jointly (convolution will run over all channels).
:param num_image_channels: Number of channels of the input. Input is expected to be of size
B x (num_image_channels * 10) x Z x Y x X, where B is the batch dimension.
B x (num_image_channels * 10) x Z x Y x X, where B is the batch dimension.
:param use_mixed_precision: If True, assume that training happens with mixed precision. Segmentations will
be converted to float16 tensors right away. If False, segmentations will be converted to float32 tensors.
be converted to float16 tensors right away. If False, segmentations will be converted to float32 tensors.
"""
super().__init__()
self.encoder_input_channels = \
@ -194,6 +195,7 @@ class MultiSegmentationEncoder(DeviceAwareModule[ScalarItem, torch.Tensor]):
def get_input_tensors(self, item: ScalarItem) -> List[torch.Tensor]:
"""
Transforms a classification item into a torch.Tensor that the forward pass can consume
:param item: ClassificationItem
:return: Tensor
"""

Просмотреть файл

@ -27,9 +27,11 @@ class ComplexModel(BaseSegmentationModel):
crop_size_constraints: Optional[CropSizeConstraints] = None):
"""
Creates a new instance of the class.
:param args: The full model configuration.
:param full_channels_list: A vector of channel sizes. First entry is the number of image channels,
then all feature channels, then the number of classes.
then all feature channels, then the number of classes.
:param network_definition:
:param crop_size_constraints: The size constraints for the training crop size.
"""

Просмотреть файл

@ -59,6 +59,7 @@ class MLP(DeviceAwareModule[ScalarItem, torch.Tensor]):
def get_input_tensors(self, item: ScalarItem) -> List[torch.Tensor]:
"""
Transforms a classification item into a torch.Tensor that the forward pass can consume
:param item: ClassificationItem
:return: Tensor
"""

Просмотреть файл

@ -26,15 +26,18 @@ class UNet2D(UNet3D):
"""
Initializes a 2D UNet model, where the input image is expected as a 3 dimensional tensor with a vanishing
Z dimension.
:param input_image_channels: The number of image channels that the model should consume.
:param initial_feature_channels: The number of feature maps used in the model in the first convolution layer.
Subsequent layers will contain number of feature maps that are multiples of `initial_channels`
Subsequent layers will contain number of feature maps that are multiples of `initial_channels`
(2^(image_level) * initial_channels)
:param num_classes: Number of output classes
:param num_downsampling_paths: Number of image levels used in Unet (in encoding and decoding paths)
:param downsampling_dilation: An additional dilation that is used in the second convolution layer in each
of the encoding blocks of the UNet. This can be used to increase the receptive field of the network. A good
of the encoding blocks of the UNet. This can be used to increase the receptive field of the network. A good
choice is (1, 2, 2), to increase the receptive field only in X and Y.
:param padding_mode: The type of padding that should be applied.
"""
super().__init__(input_image_channels=input_image_channels,

Просмотреть файл

@ -30,6 +30,7 @@ class UNet3D(BaseSegmentationModel):
4) Support for more downsampling operations to capture larger image context and improve the performance.
The network has `num_downsampling_paths` downsampling steps on the encoding side and same number upsampling steps
on the decoding side.
:param num_downsampling_paths: Number of downsampling paths used in Unet model (default 4 image level are used)
:param num_classes: Number of output segmentation classes
:param kernel_size: Spatial support of convolution kernels used in Unet model
@ -39,11 +40,15 @@ class UNet3D(BaseSegmentationModel):
"""
Implements upsampling block for UNet architecture. The operations carried out on the input tensor are
1) Upsampling via strided convolutions 2) Concatenating the skip connection tensor 3) Two convolution layers
:param channels: A tuple containing the number of input and output channels
:param upsample_kernel_size: Spatial support of upsampling kernels. If an integer is provided, the same value
will be repeated for all three dimensions. For non-cubic kernels please pass a list or tuple with three elements
will be repeated for all three dimensions. For non-cubic kernels please pass a list or tuple with three
elements.
:param upsampling_stride: Upsamling factor used in deconvolutional layer. Similar to the `upsample_kernel_size`
parameter, if an integer is passed, the same upsampling factor will be used for all three dimensions.
parameter, if an integer is passed, the same upsampling factor will be used for all three dimensions.
:param activation: Linear/Non-linear activation function that is used after linear deconv/conv mappings.
:param depth: The depth inside the UNet at which the layer operates. This is only for diagnostic purposes.
"""
@ -120,11 +125,14 @@ class UNet3D(BaseSegmentationModel):
Implements a EncodeBlock for UNet.
A EncodeBlock is two BasicLayers without dilation and with same padding.
The first of those BasicLayer can use stride > 1, and hence will downsample.
:param channels: A list containing two elements representing the number of input and output channels
:param kernel_size: Spatial support of convolution kernels. If an integer is provided, the same value will
be repeated for all three dimensions. For non-cubic kernels please pass a tuple with three elements.
be repeated for all three dimensions. For non-cubic kernels please pass a tuple with three elements.
:param downsampling_stride: Downsampling factor used in the first convolutional layer. If an integer is
passed, the same downsampling factor will be used for all three dimensions.
passed, the same downsampling factor will be used for all three dimensions.
:param dilation: Dilation of convolution kernels - If set to > 1, kernels capture content from wider range.
:param activation: Linear/Non-linear activation function that is used after linear convolution mappings.
:param use_residual: If set to True, block2 learns the residuals while preserving the output of block1
@ -182,18 +190,22 @@ class UNet3D(BaseSegmentationModel):
crop_size_constraints=crop_size_constraints)
"""
Modified 3D-Unet Class
:param input_image_channels: Number of image channels (scans) that are fed into the model.
:param initial_feature_channels: Number of feature-maps used in the model - Subsequent layers will contain
number
number
of featuremaps that is multiples of `initial_feature_channels` (e.g. 2^(image_level) * initial_feature_channels)
:param num_classes: Number of output classes
:param kernel_size: Spatial support of conv kernels in each spatial axis.
:param num_downsampling_paths: Number of image levels used in Unet (in encoding and decoding paths)
:param downsampling_factor: Spatial downsampling factor for each tensor axis (depth, width, height). This will
be used as the stride for the first convolution layer in each encoder block.
be used as the stride for the first convolution layer in each encoder block.
:param downsampling_dilation: An additional dilation that is used in the second convolution layer in each
of the encoding blocks of the UNet. This can be used to increase the receptive field of the network. A good
of the encoding blocks of the UNet. This can be used to increase the receptive field of the network. A good
choice is (1, 2, 2), to increase the receptive field only in X and Y.
:param crop_size: The size of the crop that should be used for training.
"""

Просмотреть файл

@ -14,15 +14,18 @@ class BasicLayer(torch.nn.Module):
"""
A Basic Layer applies a 3D convolution and BatchNorm with the given channels, kernel_size, and dilation.
The output of BatchNorm layer is passed through an activation function and its output is returned.
:param channels: Number of input and output channels.
:param kernel_size: Spatial support of convolution kernels
:param stride: Kernel stride lenght for convolution op
:param padding: Feature map padding after convolution op {"constant/zero", "no_padding"}. When it is set to
"no_padding", no padding is applied. For "constant", feature-map tensor size is kept the same at the output by
"no_padding", no padding is applied. For "constant", feature-map tensor size is kept the same at the output by
padding with zeros.
:param dilation: Kernel dilation used in convolution layer
:param use_bias: If set to True, a bias parameter will be added to the layer. Default is set to False as
batch normalisation layer has an affine parameter which are used applied after the bias term is added.
batch normalisation layer has an affine parameter which are used applied after the bias term is added.
:param activation: Activation layer (e.g. nonlinearity) to be used after the convolution and batch norm operations.
"""

Просмотреть файл

@ -21,13 +21,16 @@ class CrossEntropyLoss(SupervisedLearningCriterion):
super().__init__(smoothing_eps)
"""
Multi-class cross entropy loss.
:param class_weight_power: if 1.0, weights the cross-entropy term for each class equally.
Class weights are inversely proportional to the number
of pixels belonging to each class, raised to class_weight_power
:param focal_loss_gamma: Gamma term used in focal loss to weight negative log-likelihood term:
https://arxiv.org/pdf/1708.02002.pdf equation(4-5).
When gamma equals to zero, it is equivalent to standard
CE with no class balancing. (Gamma >= 0.0)
:param ignore_index: Specifies a target value that is ignored and does not contribute
to the input gradient
"""
@ -60,6 +63,7 @@ class CrossEntropyLoss(SupervisedLearningCriterion):
def get_focal_loss_pixel_weights(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Computes weights for each pixel/sample inversely proportional to the posterior likelihood.
:param logits: Logits tensor.
:param target: Target label tensor in one-hot encoding.
"""
@ -103,6 +107,7 @@ class CrossEntropyLoss(SupervisedLearningCriterion):
Wrapper for multi-class cross entropy function implemented in PyTorch.
The implementation supports tensors with arbitrary spatial dimension.
Input logits are normalised internally in `F.cross_entropy` function.
:param output: Class logits (unnormalised), e.g. in 3D : BxCxWxHxD or in 1D BxC
:param target: Target labels encoded in one-hot representation, e.g. in 3D BxCxWxHxD or in 1D BxC
"""

Просмотреть файл

@ -14,6 +14,7 @@ class MixtureLoss(SupervisedLearningCriterion):
def __init__(self, components: List[Tuple[float, SupervisedLearningCriterion]]):
"""
Loss function defined as a weighted mixture (interpolation) of other loss functions.
:param components: a non-empty list of weights and loss function instances.
"""
super().__init__()
@ -25,6 +26,7 @@ class MixtureLoss(SupervisedLearningCriterion):
"""
Wrapper for mixture loss function implemented in PyTorch. Arguments should be suitable for the
component loss functions, typically:
:param output: Class logits (unnormalised), e.g. in 3D : BxCxWxHxD or in 1D BxC
:param target: Target labels encoded in one-hot representation, e.g. in 3D BxCxWxHxD or in 1D BxC
"""

Просмотреть файл

@ -24,9 +24,10 @@ class SoftDiceLoss(SupervisedLearningCriterion):
"""
:param eps: A small constant to smooth Sorensen-Dice Loss function. Additionally, it avoids division by zero.
:param apply_softmax: If true, the input to the loss function will be first fed through a Softmax operation.
If false, the input to the loss function will be used as is.
If false, the input to the loss function will be used as is.
:param class_weight_power: power to raise 1/C to, where C is the number of voxels in each class. Should be
non-negative to help increase accuracy on small structures.
non-negative to help increase accuracy on small structures.
"""
super().__init__()
#: Small value to avoid division by zero errors.

Просмотреть файл

@ -14,6 +14,7 @@ def move_to_device(input_tensors: List[torch.Tensor],
non_blocking: bool = False) -> Iterable[torch.Tensor]:
"""
Updates the memory location of tensors stored in a list.
:param input_tensors: List of torch tensors
:param target_device: Target device (e.g. cuda:0, cuda:1, etc). If the device is None, the tensors are not moved.
:param non_blocking: bool
@ -40,6 +41,7 @@ def group_layers_with_balanced_memory(inputs: List[torch.nn.Module],
summary: Optional[OrderedDict]) -> Generator:
"""
Groups layers in the model in a balanced way as such each group has similar size of memory requirement
:param inputs: List of input torch modules.
:param num_groups: Number of groups to be produced.
:param summary: Model summary of the input layers which is used to retrieve memory requirements.
@ -121,6 +123,7 @@ def partition_layers(layers: List[torch.nn.Module],
target_devices: List[torch.device]) -> None:
"""
Splits the models into multiple chunks and assigns each sub-model to a particular GPU
:param layers: The layers to partition
:param summary: Model architecture summary to use for partitioning
:param target_devices: The devices to partition layers into

Просмотреть файл

@ -45,8 +45,9 @@ def create_parser(yaml_file_path: Path) -> ParserResult:
Create a parser for all runner arguments, even though we are only using a subset of the arguments.
This way, we can get secrets handling in a consistent way.
In particular, this will create arguments for
--local_dataset
--azure_dataset_id
* ``--local_dataset``
* ``--azure_dataset_id``
"""
parser = create_runner_parser(SegmentationModelBase)
NormalizeAndVisualizeConfig.add_args(parser)
@ -67,9 +68,11 @@ def get_configs(default_model_config: SegmentationModelBase,
def main(yaml_file_path: Path) -> None:
"""
Invoke either by
* specifying a model, '--model Lung'
* or specifying dataset and normalization parameters separately: --azure_dataset_id=foo --norm_method=None
In addition, the arguments '--image_channel' and '--gt_channel' must be specified (see below).
* specifying a model, ``--model Lung``
* or specifying dataset and normalization parameters separately: ``--azure_dataset_id=foo --norm_method=None``
In addition, the arguments ``--image_channel`` and ``--gt_channel`` must be specified.
"""
config, runner_config, args = get_configs(SegmentationModelBase(should_validate=False), yaml_file_path)
dataset_config = DatasetConfig(name=config.azure_dataset_id,

Просмотреть файл

@ -242,6 +242,7 @@ def robust_mean_std(data: np.ndarray) -> Tuple[float, float, float, float]:
Computes robust estimates of mean and standard deviation in the given array.
The median is the robust estimate for the mean, the standard deviation is computed from the
inter-quartile ranges.
:param data: The data for which mean and std should be computed.
:return: A 4-tuple with values (median, robust_std, minimum data value, maximum data value)
"""
@ -272,9 +273,11 @@ def mri_window(image_in: np.ndarray,
around the mean of the remaining values and with a range controlled by the standard deviation and the sharpen
input parameter. The larger sharpen is, the wider the range. The resulting values are the normalised to the given
output_range, with values below and above the range being set the the boundary values.
:param image_in: The image to normalize.
:param mask: Consider only pixel values of the input image for which the mask is non-zero. If None the whole
image is considered.
image is considered.
:param output_range: The desired value range of the result image.
:param sharpen: number of standard deviation either side of mean to include in the window
:param tail: Default 1, allow window range to include more of tail of distribution.

Просмотреть файл

@ -49,12 +49,14 @@ class EnsemblePipeline(FullImageInferencePipelineBase):
aggregation_type: EnsembleAggregationType) -> InferencePipeline.Result:
"""
Helper method to aggregate results from multiple inference pipelines, based on the aggregation type provided.
:param results: inference pipeline results to aggregate. This may be a Generator to prevent multiple large
posterior arrays being held at the same time. The first element of the sequence is modified in place to
posterior arrays being held at the same time. The first element of the sequence is modified in place to
minimize memory use.
:param aggregation_type: aggregation function to use to combine the results.
:return: InferenceResult: contains a Segmentation for each of the classes and their posterior
probabilities.
probabilities.
"""
if aggregation_type != EnsembleAggregationType.Average:
raise NotImplementedError(f"Ensembling is not implemented for aggregation type: {aggregation_type}")
@ -78,12 +80,13 @@ class EnsemblePipeline(FullImageInferencePipelineBase):
"""
Performs a single inference pass for each model in the ensemble, and aggregates the results
based on the provided aggregation type.
:param image_channels: The input image channels to perform inference on in format: Channels x Z x Y x X.
:param voxel_spacing_mm: Voxel spacing to use for each dimension in (Z x Y x X) order
:param mask: A binary image used to ignore results outside it in format: Z x Y x X.
:param patient_id: The identifier of the patient this image belongs to.
:return InferenceResult: that contains Segmentation for each of the classes and their posterior
probabilities.
:return: InferenceResult: that contains Segmentation for each of the classes and their posterior
probabilities.
"""
logging.info(f"Ensembling inference pipelines ({self._get_pipeline_ids()}) "
f"predictions for patient: {patient_id}, "

Просмотреть файл

@ -56,6 +56,7 @@ class FullImageInferencePipelineBase(InferencePipelineBase):
"""
Perform connected component analysis to update segmentation with largest
connected component based on the configurations
:param results: inference results to post-process
:return: post-processed version of results
"""
@ -199,12 +200,14 @@ class InferencePipeline(FullImageInferencePipelineBase):
Creates an instance of the inference pipeline for a given epoch from a stored checkpoint.
After loading, the model parameters are checked for NaN and Infinity values.
If there is no checkpoint file for the given epoch, return None.
:param path_to_checkpoint: The path to the checkpoint that we want to load
model_config.checkpoint_folder
model_config.checkpoint_folder
:param model_config: Model related configurations.
:param pipeline_id: Numeric identifier for the pipeline (useful for logging when ensembling)
:return InferencePipeline: an instantiated inference pipeline instance, or None if there was no checkpoint
file for this epoch.
:return: InferencePipeline: an instantiated inference pipeline instance, or None if there was no checkpoint
file for this epoch.
"""
if not path_to_checkpoint.is_file():
# not raising a value error here: This is used to create individual pipelines for ensembles,
@ -244,7 +247,7 @@ class InferencePipeline(FullImageInferencePipelineBase):
:param voxel_spacing_mm: Voxel spacing to use for each dimension in (Z x Y x X) order
:param mask: A binary image used to ignore results outside it in format: Z x Y x X.
:param patient_id: The identifier of the patient this image belongs to (defaults to 0 if None provided).
:return InferenceResult: that contains Segmentation for each of the classes and their posterior probabilities.
:return: InferenceResult: that contains Segmentation for each of the classes and their posterior probabilities.
"""
torch.cuda.empty_cache()
if image_channels is None:

Просмотреть файл

@ -140,6 +140,7 @@ class ScalarEnsemblePipeline(ScalarInferencePipelineBase):
config: ScalarModelBase) -> ScalarEnsemblePipeline:
"""
Creates an ensemble pipeline from a list of checkpoints.
:param paths_to_checkpoint: List of paths to the checkpoints to be recovered.
:param config: Model configuration information.
:return:
@ -179,8 +180,9 @@ class ScalarEnsemblePipeline(ScalarInferencePipelineBase):
def aggregate_model_outputs(self, model_outputs: torch.Tensor) -> torch.Tensor:
"""
Aggregates the forward pass results from the individual models in the ensemble.
:param model_outputs: List of model outputs for every model in the ensemble.
(Number of ensembles) x (batch_size) x 1
(Number of ensembles) x (batch_size) x 1
"""
# aggregate model outputs
if self.aggregation_type == EnsembleAggregationType.Average:

Просмотреть файл

@ -28,6 +28,7 @@ def is_val_dice(name: str) -> bool:
"""
Returns true if the given metric name is a Dice score on the validation set,
for a class that is not the background class.
:param name:
:return:
"""
@ -37,6 +38,7 @@ def is_val_dice(name: str) -> bool:
def get_val_dice_names(metric_names: Iterable[str]) -> List[str]:
"""
Returns a list of those metric names from the argument that fulfill the is_val_dice predicate.
:param metric_names:
:return:
"""
@ -47,6 +49,7 @@ def plot_loss_per_epoch(metrics: Dict[str, Any], metric_name: str, label: Option
"""
Adds a plot of loss (y-axis) versus epoch (x-axis) to the current plot, if the metric
is present in the metrics dictionary.
:param metrics: A dictionary of metrics.
:param metric_name: The name of the single metric to plot.
:param label: The label for the series that will be plotted.
@ -64,9 +67,10 @@ def plot_loss_per_epoch(metrics: Dict[str, Any], metric_name: str, label: Option
def plot_val_dice_per_epoch(metrics: Dict[str, Any]) -> int:
"""
Creates a plot of all validation Dice scores per epoch, for all classes apart from background.
:param metrics:
:return: The number of series that were plotted in the graph. Can return 0 if the metrics dictionary
does not contain any validation Dice score.
does not contain any validation Dice score.
"""
plt.clf()
series_count = 0
@ -83,6 +87,7 @@ def plot_val_dice_per_epoch(metrics: Dict[str, Any]) -> int:
def add_legend(series_count: int) -> None:
"""
Adds a legend to the present plot, with the column layout depending on the number of series.
:param series_count:
:return:
"""
@ -93,6 +98,7 @@ def add_legend(series_count: int) -> None:
def resize_and_save(width_inch: int, height_inch: int, filename: PathOrString, dpi: int = 150) -> None:
"""
Resizes the present figure to the given (width, height) in inches, and saves it to the given filename.
:param width_inch: The width of the figure in inches.
:param height_inch: The height of the figure in inches.
:param filename: The filename to save to.
@ -113,13 +119,17 @@ def plot_image_and_label_contour(image: np.ndarray,
"""
Creates a plot that shows the given 2D image in greyscale, and overlays a contour that shows
where the 'labels' array has value 1.
:param image: A 2D image
:param labels: A binary 2D image, or a list of binary 2D images. A contour will be plotted for each of those
binary images.
binary images.
:param contour_arguments: A dictionary of keyword arguments that will be passed directly into matplotlib's
contour function. Can also be a list of dictionaries, with one dict per entry in the 'labels' argument.
contour function. Can also be a list of dictionaries, with one dict per entry in the 'labels' argument.
:param image_range: If provided, the image will be plotted using the given range for the color limits.
If None, the minimum and maximum image values will be mapped to the endpoints of the color map.
If None, the minimum and maximum image values will be mapped to the endpoints of the color map.
:param plot_file_name: The file name that should be used to save the plot.
"""
if image.ndim != 2:
@ -183,6 +193,7 @@ def plot_before_after_statistics(image_before: np.ndarray,
that were obtained before and after a transformation of pixel values.
The plot contains histograms, box plots, and visualizations of a single XY slice at z_slice.
If a mask argument is provided, only the image pixel values inside of the mask will be plotted.
:param image_before: The first image for which to plot statistics.
:param image_after: The second image for which to plot statistics.
:param mask: Indicators with 1 for foreground, 0 for background. If None, plot statistics for all image pixels.
@ -237,10 +248,13 @@ def plot_normalization_result(loaded_images: Sample,
The first plot contains pixel value histograms before and after photometric normalization.
The second plot contains the normalized image, overlayed with contours for the foreground pixels,
at the slice where the foreground has most pixels.
:param loaded_images: An instance of Sample with the image and the labels. The first channel of the image will
be plotted.
be plotted.
:param image_range: The image value range that will be mapped to the color map. If None, the full image range
will be mapped to the colormap.
will be mapped to the colormap.
:param normalizer: The photometric normalization that should be applied.
:param result_folder: The folder into which the resulting PNG files should be written.
:param result_prefix: The prefix for all output filenames.
@ -283,13 +297,15 @@ def plot_contours_for_all_classes(sample: Sample,
"""
Creates a plot with the image, the ground truth, and the predicted segmentation overlaid. One plot is created
for each class, each plotting the Z slice where the ground truth has most pixels.
:param sample: The image sample, with the photonormalized image and the ground truth labels.
:param segmentation: The predicted segmentation: multi-value, size Z x Y x X.
:param foreground_class_names: The names of all classes, excluding the background class.
:param result_folder: The folder into which the resulting plot PNG files should be written.
:param result_prefix: A string prefix that will be used for all plots.
:param image_range: The minimum and maximum image values that will be mapped to the color map ranges.
If None, use the actual min and max values.
If None, use the actual min and max values.
:param channel_index: The index of the image channel that should be plotted.
:return: The paths to all generated PNG files.
"""
@ -337,6 +353,7 @@ def segmentation_and_groundtruth_plot(prediction: np.ndarray, ground_truth: np.n
"""
Plot predicted and the ground truth segmentations. Always plots the middle slice (to match surface distance
plots), which can sometimes lead to an empty plot.
:param prediction: 3D volume (X x Y x Z) of predicted segmentation
:param ground_truth: 3D volume (X x Y x Z) of ground truth segmentation
:param subject_id: ID of subject for annotating plot
@ -389,6 +406,7 @@ def surface_distance_ground_truth_plot(ct: np.ndarray, ground_truth: np.ndarray,
annotator: str = None) -> None:
"""
Plot surface distances where prediction > 0, with ground truth contour
:param ct: CT scan
:param ground_truth: Ground truth segmentation
:param sds_full: Surface distances (full= where prediction > 0)
@ -471,10 +489,12 @@ def scan_with_transparent_overlay(scan: np.ndarray,
information in the range [0, 1]. High values of the `overlay` are shown as opaque red, low values as transparent
red.
Plots are created in the current axis.
:param scan: A 3-dimensional image in (Z, Y, X) ordering
:param overlay: A 3-dimensional image in (Z, Y, X) ordering, with values between 0 and 1.
:param dimension: The array dimension along with the plot should be created. dimension=0 will generate
an axial slice.
an axial slice.
:param position: The index in the chosen dimension where the plot should be created.
:param spacing: The tuple of voxel spacings, in (Z, Y, X) order.
"""

Просмотреть файл

@ -73,6 +73,7 @@ def get_dataframe_with_exact_label_matches(metrics_df: pd.DataFrame,
The dataframe must have at least the following columns (defined in the LoggingColumns enum):
LoggingColumns.Hue, LoggingColumns.Patient, LoggingColumns.Label, LoggingColumns.ModelOutput.
Any other columns will be ignored.
:param prediction_target_set_to_match: The set of prediction targets to which each sample is compared
:param all_prediction_targets: The entire set of prediction targets on which the model is trained
:param thresholds_per_prediction_target: Thresholds per prediction target to decide if model has predicted True or
@ -139,9 +140,9 @@ def print_metrics_for_thresholded_output_for_all_prediction_targets(csv_to_set_o
prediction targets that exist in the dataset (i.e. for every subset of classes that occur in the dataset).
:param csv_to_set_optimal_threshold: Csv written during inference time for the val set. This is used to determine
the optimal threshold for classification.
the optimal threshold for classification.
:param csv_to_compute_metrics: Csv written during inference time for the test set. Metrics are calculated for
this csv.
this csv.
:param config: Model config
"""

Просмотреть файл

@ -69,9 +69,11 @@ def read_csv_and_filter_prediction_target(csv: Path, prediction_target: str,
:param csv: Path to the metrics CSV file. Must contain at least the following columns (defined in the LoggingColumns
enum): LoggingColumns.Patient, LoggingColumns.Hue.
:param prediction_target: Target ("hue") by which to filter.
:param crossval_split_index: If specified, filter rows only for the respective run (requires
LoggingColumns.CrossValidationSplitIndex).
:param data_split: If specified, filter rows by Train/Val/Test (requires LoggingColumns.DataSplit).
:param epoch: If specified, filter rows for given epoch (default: last epoch only; requires LoggingColumns.Epoch).
:return: Filtered dataframe.
@ -122,9 +124,11 @@ def get_labels_and_predictions(csv: Path, prediction_target: str,
:param csv: Path to the metrics CSV file. Must contain at least the following columns (defined in the LoggingColumns
enum): LoggingColumns.Patient, LoggingColumns.Hue.
:param prediction_target: Target ("hue") by which to filter.
:param crossval_split_index: If specified, filter rows only for the respective run (requires
LoggingColumns.CrossValidationSplitIndex).
:param data_split: If specified, filter rows by Train/Val/Test (requires LoggingColumns.DataSplit).
:param epoch: If specified, filter rows for given epoch (default: last epoch only; requires LoggingColumns.Epoch).
:return: Filtered labels and model outputs.
@ -150,6 +154,7 @@ def get_labels_and_predictions_from_dataframe(df: pd.DataFrame) -> LabelsAndPred
def format_pr_or_roc_axes(plot_type: str, ax: Axes) -> None:
"""
Format PR or ROC plot with appropriate title, axis labels, limits, and grid.
:param plot_type: Either 'pr' or 'roc'.
:param ax: Axes object to format.
"""
@ -171,9 +176,11 @@ def plot_pr_and_roc_curves(labels_and_model_outputs: LabelsAndPredictions, axs:
plot_kwargs: Optional[Dict[str, Any]] = None) -> None:
"""
Given labels and model outputs, plot the ROC and PR curves.
:param labels_and_model_outputs:
:param axs: Pair of axes objects onto which to plot the ROC and PR curves, respectively. New axes are created by
default.
default.
:param plot_kwargs: Plotting options to be passed to both `ax.plot(...)` calls.
"""
if axs is None:
@ -202,15 +209,18 @@ def plot_scores_and_summary(all_labels_and_model_outputs: Sequence[LabelsAndPred
Each plotted curve is interpolated onto a common horizontal grid, and the median and CI are computed vertically
at each horizontal location.
:param all_labels_and_model_outputs: Collection of ground-truth labels and model predictions (e.g. for various
cross-validation runs).
cross-validation runs).
:param scoring_fn: A scoring function mapping a `LabelsAndPredictions` object to X and Y coordinates for plotting.
:param interval_width: A value in [0, 1] representing what fraction of the data should be contained in
the shaded area. The edges of the interval are `median +/- interval_width/2`.
the shaded area. The edges of the interval are `median +/- interval_width/2`.
:param ax: Axes object onto which to plot (default: use current axes).
:return: A tuple of `(line_handles, summary_handle)` to use in setting a legend for the plot: `line_handles` is a
list corresponding to the curves for each `LabelsAndPredictions`, and `summary_handle` references the median line
and shaded CI area.
list corresponding to the curves for each `LabelsAndPredictions`, and `summary_handle` references the median line
and shaded CI area.
"""
if ax is None:
ax = plt.gca()
@ -236,10 +246,12 @@ def plot_pr_and_roc_curves_crossval(all_labels_and_model_outputs: Sequence[Label
"""
Given a list of LabelsAndPredictions objects, plot the corresponding ROC and PR curves, along with median line and
shaded 80% confidence interval (computed over TPRs and precisions for each fixed FPR and recall value).
:param all_labels_and_model_outputs: Collection of ground-truth labels and model predictions (e.g. for various
cross-validation runs).
cross-validation runs).
:param axs: Pair of axes objects onto which to plot the ROC and PR curves, respectively. New axes are created by
default.
default.
"""
if axs is None:
_, axs = plt.subplots(1, 2)
@ -276,11 +288,12 @@ def plot_pr_and_roc_curves_from_csv(metrics_csv: Path, config: ScalarModelBase,
"""
Given the CSV written during inference time and the model config, plot the ROC and PR curves for all prediction
targets.
:param metrics_csv: Path to the metrics CSV file.
:param config: Model config.
:param data_split: Whether to filter the CSV file for Train, Val, or Test results (default: no filtering).
:param is_crossval_report: If True, assumes CSV contains results for multiple cross-validation runs and plots the
curves with median and confidence intervals. Otherwise, plots curves for a single run.
curves with median and confidence intervals. Otherwise, plots curves for a single run.
"""
for prediction_target in config.target_names:
print_header(f"Class: {prediction_target}", level=3)
@ -300,8 +313,10 @@ def get_metric(predictions_to_set_optimal_threshold: LabelsAndPredictions,
optimal_threshold: Optional[float] = None) -> float:
"""
Given LabelsAndPredictions objects for the validation and test sets, return the specified metric.
:param predictions_to_set_optimal_threshold: This set of ground truth labels and model predictions is used to
determine the optimal threshold for classification.
determine the optimal threshold for classification.
:param predictions_to_compute_metrics: The set of labels and model outputs to calculate metrics for.
:param metric: The name of the metric to calculate.
:param optimal_threshold: If provided, use this threshold instead of calculating an optimal threshold.
@ -351,6 +366,7 @@ def get_all_metrics(predictions_to_set_optimal_threshold: LabelsAndPredictions,
is_thresholded: bool = False) -> Dict[str, float]:
"""
Given LabelsAndPredictions objects for the validation and test sets, compute some metrics.
:param predictions_to_set_optimal_threshold: This is used to determine the optimal threshold for classification.
:param predictions_to_compute_metrics: Metrics are calculated for this set.
:param is_thresholded: Whether the model outputs are binary (they have been thresholded at some point)
@ -379,6 +395,7 @@ def print_metrics(predictions_to_set_optimal_threshold: LabelsAndPredictions,
is_thresholded: bool = False) -> None:
"""
Given LabelsAndPredictions objects for the validation and test sets, print out some metrics.
:param predictions_to_set_optimal_threshold: This is used to determine the optimal threshold for classification.
:param predictions_to_compute_metrics: Metrics are calculated for this set.
:param is_thresholded: Whether the model outputs are binary (they have been thresholded at some point)
@ -402,16 +419,21 @@ def get_metrics_table_for_prediction_target(csv_to_set_optimal_threshold: Path,
:param csv_to_set_optimal_threshold: CSV written during inference time for the val set. This is used to determine
the optimal threshold for classification.
:param csv_to_compute_metrics: CSV written during inference time for the test set. Metrics are calculated for
this CSV.
:param config: Model config
:param prediction_target: The prediction target for which to compute metrics.
:param data_split_to_set_optimal_threshold: Whether to filter the validation CSV file for Train, Val, or Test
results (default: no filtering).
:param data_split_to_compute_metrics: Whether to filter the test CSV file for Train, Val, or Test results
(default: no filtering).
:param is_thresholded: Whether the model outputs are binary (they have been thresholded at some point)
or are floating point numbers.
:param is_crossval_report: If True, assumes CSVs contain results for multiple cross-validation runs and formats the
metrics along with means and standard deviations. Otherwise, collect metrics for a single run.
:return: Tuple of rows and header, where each row and the header are lists of strings of same length (2 if
@ -468,15 +490,20 @@ def print_metrics_for_all_prediction_targets(csv_to_set_optimal_threshold: Path,
:param csv_to_set_optimal_threshold: CSV written during inference time for the val set. This is used to determine
the optimal threshold for classification.
:param csv_to_compute_metrics: CSV written during inference time for the test set. Metrics are calculated for
this CSV.
:param config: Model config
:param data_split_to_set_optimal_threshold: Whether to filter the validation CSV file for Train, Val, or Test
results (default: no filtering).
:param data_split_to_compute_metrics: Whether to filter the test CSV file for Train, Val, or Test results
(default: no filtering).
:param is_thresholded: Whether the model outputs are binary (they have been thresholded at some point)
or are floating point numbers.
:param is_crossval_report: If True, assumes CSVs contain results for multiple cross-validation runs and prints the
metrics along with means and standard deviations. Otherwise, prints metrics for a single run.
"""
@ -563,11 +590,14 @@ def print_k_best_and_worst_performing(val_metrics_csv: Path,
"""
Print the top "k" best predictions (i.e. correct classifications where the model was the most certain) and the
top "k" worst predictions (i.e. misclassifications where the model was the most confident).
:param val_metrics_csv: Path to one of the metrics csvs written during inference. This set of metrics will be
used to determine the thresholds for predicting labels on the test set. The best and worst
performing subjects will not be printed out for this csv.
:param test_metrics_csv: Path to one of the metrics csvs written during inference. This is the csv for which
best and worst performing subjects will be printed out.
:param k: Number of subjects of each category to print out.
:param prediction_target: The class label to filter on
"""
@ -605,6 +635,7 @@ def get_image_filepath_from_subject_id(subject_id: str,
config: ScalarModelBase) -> List[Path]:
"""
Return the filepaths for images associated with a subject. If the subject is not found, raises a ValueError.
:param subject_id: Subject to retrive image for
:param dataset: scalar dataset object
:param config: model config
@ -623,6 +654,7 @@ def get_image_labels_from_subject_id(subject_id: str,
config: ScalarModelBase) -> List[str]:
"""
Return the ground truth labels associated with a subject. If the subject is not found, raises a ValueError.
:param subject_id: Subject to retrive image for
:param dataset: scalar dataset object
:param config: model config
@ -657,6 +689,7 @@ def get_image_outputs_from_subject_id(subject_id: str,
def plot_image_from_filepath(filepath: Path, im_width: int) -> bool:
"""
Plots a 2D image given the filepath. Returns false if the image could not be plotted (for example, if it was 3D).
:param filepath: Path to image
:param im_width: Display width for image
:return: True if image was plotted, False otherwise
@ -693,6 +726,7 @@ def plot_image_for_subject(subject_id: str,
metrics_df: Optional[pd.DataFrame] = None) -> None:
"""
Given a subject ID, plots the corresponding image.
:param subject_id: Subject to plot image for
:param dataset: scalar dataset object
:param im_width: Display width for image
@ -741,11 +775,14 @@ def plot_k_best_and_worst_performing(val_metrics_csv: Path, test_metrics_csv: Pa
"""
Plot images for the top "k" best predictions (i.e. correct classifications where the model was the most certain)
and the top "k" worst predictions (i.e. misclassifications where the model was the most confident).
:param val_metrics_csv: Path to one of the metrics csvs written during inference. This set of metrics will be
used to determine the thresholds for predicting labels on the test set. The best and worst
performing subjects will not be printed out for this csv.
:param test_metrics_csv: Path to one of the metrics csvs written during inference. This is the csv for which
best and worst performing subjects will be printed out.
:param k: Number of subjects of each category to print out.
:param prediction_target: The class label to filter on
:param config: scalar model config object

Просмотреть файл

@ -26,6 +26,7 @@ reports_folder = "reports"
def get_ipynb_report_name(report_type: str) -> str:
"""
Constructs the name of the report (as an ipython notebook).
:param report_type: suffix describing the report, added to the filename
:return:
"""
@ -35,6 +36,7 @@ def get_ipynb_report_name(report_type: str) -> str:
def get_html_report_name(report_type: str) -> str:
"""
Constructs the name of the report (as an html file).
:param report_type: suffix describing the report, added to the filename
:return:
"""
@ -49,6 +51,7 @@ def print_header(message: str, level: int = 2) -> None:
"""
Displays a message, and afterwards repeat it as Markdown with the given indentation level (level=1 is the
outermost, `# Foo`.
:param message: The header string to display.
:param level: The Markdown indentation level. level=1 for top level, level=3 for `### Foo`
"""
@ -59,6 +62,7 @@ def print_header(message: str, level: int = 2) -> None:
def print_table(rows: Sequence[Sequence[str]], header: Optional[Sequence[str]] = None) -> None:
"""
Displays the provided content in a formatted HTML table, with optional column headers.
:param rows: List of rows, where each row is a list of string-valued cell contents.
:param header: List of column headers. If given, this special first row is rendered with emphasis.
"""
@ -74,6 +78,7 @@ def print_table(rows: Sequence[Sequence[str]], header: Optional[Sequence[str]] =
def generate_notebook(template_notebook: Path, notebook_params: Dict, result_notebook: Path) -> Path:
"""
Generates a notebook report as jupyter notebook and html page
:param template_notebook: path to template notebook
:param notebook_params: parameters for the notebook
:param result_notebook: the path for the executed notebook

Просмотреть файл

@ -40,6 +40,7 @@ def plot_scores_for_csv(path_csv: str, outlier_range: float, max_row_count: int)
def display_without_index(df: pd.DataFrame) -> None:
"""
Prints the given dataframe as HTML via the `display` function, but without the index column.
:param df: The dataframe to print.
"""
display(HTML(df.to_html(index=False)))
@ -52,12 +53,13 @@ def display_metric(df: pd.DataFrame,
high_values_are_good: bool) -> None:
"""
Displays a dataframe with a metric per structure, first showing the
:param max_row_count: The number of rows to print when showing the lowest score patients
:param df: The dataframe with metrics per structure
:param metric_name: The metric to sort by.
:param outlier_range: The standard deviation range data points must fall outside of to be considered an outlier
:param high_values_are_good: If true, high values of the metric indicate good performance. If false, low
values indicate good performance.
values indicate good performance.
"""
print_header(metric_name, level=2)
# Display with best structures first.
@ -80,11 +82,13 @@ def worst_patients_and_outliers(df: pd.DataFrame,
Prints a dataframe that contains the worst patients by the given metric, and a column indicating whether the
performance is so poor that it is considered an outlier: metric value which is outside of
outlier_range * standard deviation from the mean.
:param df: The dataframe with metrics.
:param outlier_range: The multiplier for standard deviation when constructing the interval for outliers.
:param metric_name: The metric for which the "worst" patients should be computed.
:param high_values_are_good: If True, high values for the metric are considered good, and hence low values
are marked as outliers. If False, low values are considered good, and high values are marked as outliers.
are marked as outliers. If False, low values are considered good, and high values are marked as outliers.
:param max_row_count: The maximum number of rows to print.
:return:
"""

Просмотреть файл

@ -69,6 +69,7 @@ def check_dataset_folder_exists(local_dataset: PathOrString) -> Path:
"""
Checks if a folder with a local dataset exists. If it does exist, return the argument converted to a Path instance.
If it does not exist, raise a FileNotFoundError.
:param local_dataset: The dataset folder to check.
:return: The local_dataset argument, converted to a Path.
"""
@ -83,6 +84,7 @@ def log_metrics(metrics: Dict[ModelExecutionMode, InferenceMetrics],
run_context: Run) -> None:
"""
Log metrics for each split to the provided run, or the current run context if None provided
:param metrics: Dictionary of inference results for each split.
:param run_context: Run for which to log the metrics to, use the current run context if None provided
"""
@ -111,20 +113,26 @@ class MLRunner:
"""
Driver class to run a ML experiment. Note that the project root argument MUST be supplied when using InnerEye
as a package!
:param model_config: If None, run the training as per the `container` argument (bring-your-own-model). If not
None, this is the model configuration for a built-in InnerEye model.
None, this is the model configuration for a built-in InnerEye model.
:param container: The LightningContainer object to use for training. If None, assume that the training is
for a built-in InnerEye model.
for a built-in InnerEye model.
:param azure_config: Azure related configurations
:param project_root: Project root. This should only be omitted if calling run_ml from the test suite. Supplying
it is crucial when using InnerEye as a package or submodule!
it is crucial when using InnerEye as a package or submodule!
:param post_cross_validation_hook: A function to call after waiting for completion of cross validation runs.
The function is called with the model configuration and the path to the downloaded and merged metrics files.
The function is called with the model configuration and the path to the downloaded and merged metrics files.
:param model_deployment_hook: an optional function for deploying a model in an application-specific way.
If present, it should take a LightningContainer, an AzureConfig, an AzureML Model and a ModelProcessing object
If present, it should take a LightningContainer, an AzureConfig, an AzureML Model and a ModelProcessing object
as arguments, and return an object of any type.
:param output_subfolder: If provided, the output folder structure will have an additional subfolder,
when running outside AzureML.
when running outside AzureML.
"""
self.model_config = model_config
if container is None:
@ -145,8 +153,9 @@ class MLRunner:
"""
If the present object is using one of the InnerEye built-in models, create a (fake) container for it
and call the setup method. It sets the random seeds, and then creates the actual Lightning modules.
:param azure_run_info: When running in AzureML or on a local VM, this contains the paths to the datasets.
This can be missing when running in unit tests, where the local dataset paths are already populated.
This can be missing when running in unit tests, where the local dataset paths are already populated.
"""
if self._has_setup_run:
return
@ -404,6 +413,7 @@ class MLRunner:
def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> None:
"""
Run inference on the test set for all models that are specified via a LightningContainer.
:param checkpoint_paths: The path to the checkpoint that should be used for inference.
"""
if len(checkpoint_paths) != 1:
@ -461,9 +471,10 @@ class MLRunner:
model_proc: ModelProcessing) -> None:
"""
Run inference on InnerEyeContainer models
:param checkpoint_paths: Checkpoint paths to initialize model
:param model_proc: whether we are running an ensemble model from within a child run with index 0. If we are,
then outputs will be written to OTHER_RUNS/ENSEMBLE under the main outputs directory.
then outputs will be written to OTHER_RUNS/ENSEMBLE under the main outputs directory.
"""
# run full image inference on existing or newly trained model on the training, and testing set
@ -515,10 +526,11 @@ class MLRunner:
"""
Registers a new model in the workspace's model registry on AzureML to be deployed further.
The AzureML run's tags are updated to describe with information about ensemble creation and the parent run ID.
:param checkpoint_paths: Checkpoint paths to register.
:param model_proc: whether it's a single or ensemble model.
:returns Tuple element 1: AML model object, or None if no model could be registered.
Tuple element 2: The result of running the model_deployment_hook, or None if no hook was supplied.
:return: Tuple element 1: AML model object, or None if no model could be registered.
Tuple element 2: The result of running the model_deployment_hook, or None if no hook was supplied.
"""
if self.is_offline_run:
raise ValueError("Cannot register models when InnerEye is running outside of AzureML.")
@ -601,9 +613,11 @@ class MLRunner:
extra_code_directory, and all checkpoints in a newly created "checkpoints" folder inside the model.
In addition, the name of the present AzureML Python environment will be written to a file, for later use
in the inference code.
:param model_folder: The folder into which all files should be copied.
:param checkpoint_paths: A list with absolute paths to checkpoint files. They are expected to be
inside of the model's checkpoint folder.
inside of the model's checkpoint folder.
:param python_environment: The Python environment that is used in the present AzureML run.
"""
@ -704,6 +718,7 @@ class MLRunner:
def wait_for_runs_to_finish(self, delay: int = 60) -> None:
"""
Wait for cross val runs (apart from the current one) to finish and then aggregate results of all.
:param delay: How long to wait between polls to AML to get status of child runs
"""
with logging_section("Waiting for sibling runs"):
@ -716,7 +731,7 @@ class MLRunner:
or cancelled.
:return: True if all sibling runs of the current run have finished (they either completed successfully,
or failed). False if any of them is still pending (running or queued).
or failed). False if any of them is still pending (running or queued).
"""
if (not self.is_offline_run) \
and (azure_util.is_cross_validation_child_run(RUN_CONTEXT)):

Просмотреть файл

@ -110,12 +110,14 @@ class Runner:
"""
This class contains the high-level logic to start a training run: choose a model configuration by name,
submit to AzureML if needed, or otherwise start the actual training and test loop.
:param project_root: The root folder that contains all of the source code that should be executed.
:param yaml_config_file: The path to the YAML file that contains values to supply into sys.argv.
:param post_cross_validation_hook: A function to call after waiting for completion of cross validation runs.
The function is called with the model configuration and the path to the downloaded and merged metrics files.
The function is called with the model configuration and the path to the downloaded and merged metrics files.
:param model_deployment_hook: an optional function for deploying a model in an application-specific way.
If present, it should take a model config (SegmentationModelBase), an AzureConfig, and an AzureML
If present, it should take a model config (SegmentationModelBase), an AzureConfig, and an AzureML
Model as arguments, and return an optional Path and a further object of any type.
"""
@ -203,7 +205,7 @@ class Runner:
The main entry point for training and testing models from the commandline. This chooses a model to train
via a commandline argument, runs training or testing, and writes all required info to disk and logs.
:return: If submitting to AzureML, returns the model configuration that was used for training,
including commandline overrides applied (if any).
including commandline overrides applied (if any).
"""
# Usually, when we set logging to DEBUG, we want diagnostics about the model
# build itself, but not the tons of debug information that AzureML submissions create.
@ -378,8 +380,9 @@ class Runner:
def run_in_situ(self, azure_run_info: AzureRunInfo) -> None:
"""
Actually run the AzureML job; this method will typically run on an Azure VM.
:param azure_run_info: Contains all information about the present run in AzureML, in particular where the
datasets are mounted.
datasets are mounted.
"""
# Only set the logging level now. Usually, when we set logging to DEBUG, we want diagnostics about the model
# build itself, but not the tons of debug information that AzureML submissions create.
@ -423,6 +426,7 @@ class Runner:
def default_post_cross_validation_hook(config: ModelConfigBase, root_folder: Path) -> None:
"""
A function to run after cross validation results have been aggregated, before they are uploaded to AzureML.
:param config: The configuration of the model that should be trained.
:param root_folder: The folder with all aggregated and per-split files.
"""
@ -446,7 +450,7 @@ def run(project_root: Path,
The main entry point for training and testing models from the commandline. This chooses a model to train
via a commandline argument, runs training or testing, and writes all required info to disk and logs.
:return: If submitting to AzureML, returns the model configuration that was used for training,
including commandline overrides applied (if any). For details on the arguments, see the constructor of Runner.
including commandline overrides applied (if any). For details on the arguments, see the constructor of Runner.
"""
runner = Runner(project_root, yaml_config_file, post_cross_validation_hook, model_deployment_hook)
return runner.run()

Просмотреть файл

@ -89,10 +89,11 @@ class LabelTransformation(Enum):
def get_scaling_transform(max_value: int = 100, min_value: int = 0, last_in_pipeline: bool = True) -> Callable:
"""
Defines the function to scale labels.
:param max_value:
:param min_value:
:param last_in_pipeline: if the transformation is the last
in the pipeline it should expect a single label as an argument.
in the pipeline it should expect a single label as an argument.
Else if returns a list of scaled labels for further transforms.
:return: The scaling function
"""
@ -340,6 +341,7 @@ class ScalarModelBase(ModelConfigBase):
def filter_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
"""
Filter dataframes based on expected values on columns
:param df: the input dataframe
:return: the filtered dataframe
"""
@ -542,16 +544,15 @@ class ScalarModelBase(ModelConfigBase):
data_split: ModelExecutionMode) -> None:
"""
Computes all the metrics for a given (logits, labels) pair, and writes them to the loggers.
:param logits: The model output before normalization.
:param targets: The expected model outputs.
:param subject_ids: The subject IDs for the present minibatch.
:param is_training: If True, write the metrics as training metrics, otherwise as validation metrics.
:param metrics: A dictionary mapping from names of prediction targets to a list of metric computers,
as returned by create_metric_computers.
:param metrics: A dictionary mapping from names of prediction targets to a list of metric computers, as returned by create_metric_computers.
:param logger: An object of type DataframeLogger which can be be used for logging within this function.
:param current_epoch: Current epoch number.
:param data_split: ModelExecutionMode object indicating if this is the train or validation split.
:return:
"""
per_subject_outputs: List[Tuple[str, str, torch.Tensor, torch.Tensor]] = []
for i, (prediction_target, metric_list) in enumerate(metrics.items()):
@ -592,9 +593,10 @@ def get_non_image_features_dict(default_channels: List[str],
Returns the channels dictionary for non-imaging features.
:param default_channels: the channels to use for all features except the features specified
in specific_channels
in specific_channels
:param specific_channels: a dictionary mapping feature names to channels for all features that do
not use the default channels
not use the default channels
"""
non_imaging_features_dict = {DEFAULT_KEY: default_channels}
if specific_channels is not None:

Просмотреть файл

@ -43,13 +43,14 @@ def load_predictions(run_type: SurfaceDistanceRunType, azure_config: AzureConfig
) -> List[Segmentation]:
"""
For each run type (IOV or outliers), instantiate a list of predicted Segmentations and return
:param run_type: either "iov" or "outliers:
:param azure_config: AzureConfig
:param model_config: GenericConfig
:param execution_mode: ModelExecutionMode: Either Test, Train or Val
:param extended_annotators: List of annotators plus model_name to load segmentations for
:param outlier_range: The standard deviation from the mean which the points have to be below
to be considered an outlier.
to be considered an outlier.
:return: list of [(subject_id, structure name and dice_scores)]
"""
predictions = []

Просмотреть файл

@ -74,8 +74,9 @@ class CheckpointHandler:
Download checkpoints from a run recovery object or from a weights url. Set the checkpoints path based on the
run_recovery_object, weights_url or local_weights_path.
This is called at the start of training.
:param: only_return_path: if True, return a RunRecovery object with the path to the checkpoint without actually
downloading the checkpoints. This is useful to avoid duplicating checkpoint download when running on multiple
downloading the checkpoints. This is useful to avoid duplicating checkpoint download when running on multiple
nodes. If False, return the RunRecovery object and download the checkpoint to disk.
"""
if self.azure_config.run_recovery_id:
@ -246,6 +247,7 @@ def download_folder_from_run_to_temp_folder(folder: str,
:param run: If provided, download the files from that run. If omitted, download the files from the current run
(taken from RUN_CONTEXT)
:param workspace: The AML workspace where the run is located. If omitted, the hi-ml defaults of finding a workspace
are used (current workspace when running in AzureML, otherwise expecting a config.json file)
:return: The path to which the files were downloaded. The files are located in that folder, without any further
@ -276,6 +278,7 @@ def find_recovery_checkpoint_on_disk_or_cloud(path: Path) -> Optional[Path]:
"""
Looks at all the checkpoint files and returns the path to the one that should be used for recovery.
If no checkpoint files are found on disk, the function attempts to download from the current AzureML run.
:param path: The folder to start searching in.
:return: None if there is no suitable recovery checkpoints, or else a full path to the checkpoint file.
"""
@ -294,6 +297,7 @@ def get_recovery_checkpoint_path(path: Path) -> Path:
"""
Returns the path to the last recovery checkpoint in the given folder or the provided filename. Raises a
FileNotFoundError if no recovery checkpoint file is present.
:param path: Path to checkpoint folder
"""
recovery_checkpoint = find_recovery_checkpoint(path)
@ -308,6 +312,7 @@ def find_recovery_checkpoint(path: Path) -> Optional[Path]:
Finds the checkpoint file in the given path that can be used for re-starting the present job.
This can be an autosave checkpoint, or the last checkpoint. All existing checkpoints are loaded, and the one
for the highest epoch is used for recovery.
:param path: The folder to search in.
:return: Returns the checkpoint file to use for re-starting, or None if no such file was found.
"""
@ -337,6 +342,7 @@ def find_recovery_checkpoint(path: Path) -> Optional[Path]:
def cleanup_checkpoints(path: Path) -> None:
"""
Remove autosave checkpoints from the given checkpoint folder, and check if a "last.ckpt" checkpoint is present.
:param path: The folder that contains all checkpoint files.
"""
logging.info(f"Files in checkpoint folder: {' '.join(p.name for p in path.glob('*'))}")
@ -359,6 +365,7 @@ def download_best_checkpoints_from_child_runs(config: OutputParams, run: Run) ->
The checkpoints for the sibling runs will go into folder 'OTHER_RUNS/<cross_validation_split>'
in the checkpoint folder. There is special treatment for the child run that is equal to the present AzureML
run, its checkpoints will be read off the checkpoint folder as-is.
:param config: Model related configs.
:param run: The Hyperdrive parent run to download from.
:return: run recovery information
@ -392,12 +399,14 @@ def download_all_checkpoints_from_run(config: OutputParams, run: Run,
only_return_path: bool = False) -> RunRecovery:
"""
Downloads all checkpoints of the provided run inside the checkpoints folder.
:param config: Model related configs.
:param run: Run whose checkpoints should be recovered
:param subfolder: optional subfolder name, if provided the checkpoints will be downloaded to
CHECKPOINT_FOLDER / subfolder. If None, the checkpoint are downloaded to CHECKPOINT_FOLDER of the current run.
CHECKPOINT_FOLDER / subfolder. If None, the checkpoint are downloaded to CHECKPOINT_FOLDER of the current run.
:param: only_return_path: if True, return a RunRecovery object with the path to the checkpoint without actually
downloading the checkpoints. This is useful to avoid duplicating checkpoint download when running on multiple
downloading the checkpoints. This is useful to avoid duplicating checkpoint download when running on multiple
nodes. If False, return the RunRecovery object and download the checkpoint to disk.
:return: run recovery information
"""

Просмотреть файл

@ -60,6 +60,7 @@ class ModelConfigLoader(GenericConfig):
Given a module specification check to see if it has a class property with
the <model_name> provided, and instantiate that config class with the
provided <config_overrides>. Otherwise, return None.
:param module_spec:
:return: Instantiated model config if it was found.
"""

Просмотреть файл

@ -39,7 +39,7 @@ def load_csv(csv_path: Path, expected_cols: List[str], col_type_converters: Opti
:param csv_path: Path to file
:param expected_cols: A list of the columns which must, as a minimum, be present.
:param col_type_converters: Dictionary of column: type, which ensures certain DataFrame columns are parsed with
specific types
specific types
:return: Loaded pandas DataFrame
"""
if not expected_cols:
@ -61,6 +61,7 @@ def drop_rows_missing_important_values(df: pd.DataFrame, important_cols: List[st
"""
Remove rows from the DataFrame in which the columns that have been specified by the user as "important" contain
null values or only whitespace.
:param df: DataFrame
:param important_cols: Columns which must not contain null values
:return: df: DataFrame without the dropped rows.
@ -83,9 +84,10 @@ def extract_outliers(df: pd.DataFrame, outlier_range: float, outlier_col: str =
:param df: DataFrame from which to extract the outliers
:param outlier_range: The number of standard deviation from the mean which the points have to be apart
to be considered an outlier i.e. a point is considered an outlier if its outlier_col value is above
to be considered an outlier i.e. a point is considered an outlier if its outlier_col value is above
mean + outlier_range * std (if outlier_type is HIGH) or below mean - outlier_range * std (if outlier_type is
LOW).
:param outlier_col: The column from which to calculate outliers, e.g. Dice
:param outlier_type: Either LOW (i.e. below accepted range) or HIGH (above accepted range) outliers.
:return: DataFrame containing only the outliers
@ -108,14 +110,16 @@ def mark_outliers(df: pd.DataFrame,
Rows that are not considered outliers have an empty string in the new column.
Outliers are taken from the column `outlier_col`, that have a value that falls outside of
mean +- outlier_range * std.
:param df: DataFrame from which to extract the outliers
:param outlier_range: The number of standard deviation from the mean which the points have to be apart
to be considered an outlier i.e. a point is considered an outlier if its outlier_col value is above
to be considered an outlier i.e. a point is considered an outlier if its outlier_col value is above
mean + outlier_range * std (if outlier_type is HIGH) or below mean - outlier_range * std (if outlier_type is
LOW).
:param outlier_col: The column from which to calculate outliers, e.g. Dice
:param high_values_are_good: If True, high values for the metric are considered good, and hence low values
are marked as outliers. If False, low values are considered good, and high values are marked as outliers.
are marked as outliers. If False, low values are considered good, and high values are marked as outliers.
:return: DataFrame with an additional column `is_outlier`
"""
if outlier_range < 0:
@ -137,10 +141,12 @@ def get_worst_performing_outliers(df: pd.DataFrame,
"""
Returns a sorted list (worst to best) of all the worst performing outliers in the metrics table
according to metric provided by outlier_col_name
:param df: Metrics DataFrame
:param outlier_col_name: The column by which to determine outliers
:param outlier_range: The standard deviation from the mean which the points have to be below
to be considered an outlier.
to be considered an outlier.
:param max_n_outliers: the number of (worst performing) outlier IDs to return.
:return: a sorted list (worst to best) of all the worst performing outliers
"""

Просмотреть файл

@ -31,7 +31,7 @@ class CategoricalToOneHotEncoder(OneHotEncoderBase):
def __init__(self, columns_and_possible_categories: OrderedDict[str, List[str]]):
"""
:param columns_and_possible_categories: Mapping between dataset column names
to their possible values. eg: {'Inject': ['True', 'False']}. This is required
to their possible values. eg: {'Inject': ['True', 'False']}. This is required
to establish the one-hot encoding each of the possible values.
"""
super().__init__()
@ -75,7 +75,7 @@ class CategoricalToOneHotEncoder(OneHotEncoderBase):
def get_supported_dataset_column_names(self) -> List[str]:
"""
:returns list of categorical columns that are supported by this encoder
:return: list of categorical columns that are supported by this encoder
"""
return list(self._columns_and_possible_categories.keys())
@ -86,8 +86,8 @@ class CategoricalToOneHotEncoder(OneHotEncoderBase):
of length 3.
:param feature_name: the name of the column for which to compute the feature
length.
:returns the feature length i.e. number of possible values for this feature.
length.
:return: the feature length i.e. number of possible values for this feature.
"""
return self._feature_length[feature_name]

Просмотреть файл

@ -24,21 +24,21 @@ class DeviceAwareModule(torch.nn.Module, Generic[T, E]):
def get_devices(self) -> List[torch.device]:
"""
:return a list of device ids on which this module
is deployed.
:return: a list of device ids on which this module
is deployed.
"""
return list({x.device for x in self.parameters()})
def get_number_trainable_parameters(self) -> int:
"""
:return the number of trainable parameters in the module.
:return: the number of trainable parameters in the module.
"""
return sum(p.numel() for p in self.parameters() if p.requires_grad)
def is_model_on_gpu(self) -> bool:
"""
Checks if the model is cuda activated or not
:return True if the model is running on the GPU.
:return: True if the model is running on the GPU.
"""
try:
cuda_activated = next(self.parameters()).is_cuda
@ -51,6 +51,7 @@ class DeviceAwareModule(torch.nn.Module, Generic[T, E]):
"""
Extract the input tensors from a data sample as required
by the forward pass of the module.
:param item: a data sample
:return: the correct input tensors for the forward pass
"""

Просмотреть файл

@ -32,7 +32,7 @@ class FeatureStatistics:
:param sources: list of data sources
:return: a Feature Statistics object storing mean and standard deviation for each non-imaging feature of
the dataset.
the dataset.
"""
if len(sources) == 0:
raise ValueError("sources must have a length greater than 0")
@ -88,7 +88,7 @@ class FeatureStatistics:
All features that have zero standard deviation (constant features) are left untouched.
:param sources: list of datasources.
:return list of data sources where all non-imaging features are standardized.
:return: list of data sources where all non-imaging features are standardized.
"""
def apply_source(source: ScalarDataSource) -> ScalarDataSource:

Просмотреть файл

@ -72,9 +72,10 @@ class HDF5Object:
def parse_acquisition_date(date: str) -> Optional[datetime]:
"""
Converts a string representing a date to a datetime object
:param date: string representing a date
:return: converted date, None if the string is invalid for
date conversion.
date conversion.
"""
try:
return datetime.strptime(date, DATE_FORMAT)
@ -90,6 +91,7 @@ class HDF5Object:
def _load_image(hdf5_data: h5py.File, data_field: HDF5Field) -> np.ndarray:
"""
Load the volume from the HDF5 file.
:param hdf5_data: path to the hdf5 file
:param data_field: field of the hdf5 file containing the data
:return: image as numpy array

Просмотреть файл

@ -48,6 +48,7 @@ def get_unit_image_header(spacing: Optional[TupleFloat3] = None) -> ImageHeader:
"""
Creates an ImageHeader object with the origin at 0, and unit direction. The spacing is set to the argument,
defaulting to (1, 1, 1) if not provided.
:param spacing: The image spacing, as a (Z, Y, X) tuple.
"""
if not spacing:
@ -75,7 +76,7 @@ def apply_mask_to_posteriors(posteriors: NumpyOrTorch, mask: NumpyOrTorch) -> Nu
:param posteriors: image tensors in shape: Batches (optional) x Classes x Z x Y x X
:param mask: image tensor in shape: Batches (optional) x Z x Y x X
:return posteriors with mask applied
:return: posteriors with mask applied
"""
ml_util.check_size_matches(posteriors, mask, matching_dimensions=[-1, -2, -3])
@ -184,8 +185,9 @@ def _pad_images(images: np.ndarray,
:param images: the image(s) to be padded, in shape: Z x Y x X or batched in shape: Batches x Z x Y x X.
:param padding_vector: padding before and after in each dimension eg: ((2,2), (3,3), (2,0))
will pad 4 pixels in Z (2 on each side), 6 pixels in Y (3 on each side)
will pad 4 pixels in Z (2 on each side), 6 pixels in Y (3 on each side)
and 2 in X (2 on the left and 0 on the right).
:param padding_mode: a valid numpy padding mode.
:return: padded copy of the original image.
"""
@ -208,9 +210,9 @@ def posteriors_to_segmentation(posteriors: NumpyOrTorch) -> NumpyOrTorch:
Perform argmax on the class dimension.
:param posteriors: Confidence maps [0,1] for each patch per class in format: Batches x Class x Z x Y x X
or Class x Z x Y x X for non-batched input
:returns segmentation: argmaxed posteriors with each voxel belonging to a single class: Batches x Z x Y x X
or Z x Y x X for non-batched input
or Class x Z x Y x X for non-batched input
:return: segmentation: argmaxed posteriors with each voxel belonging to a single class: Batches x Z x Y x X
or Z x Y x X for non-batched input
"""
if posteriors is None:
@ -241,9 +243,11 @@ def largest_connected_components(img: np.ndarray,
Select the largest connected binary components (plural) in an image. If deletion_limit is set in which case a
component is only deleted (i.e. its voxels are False in the output) if its voxel count as a proportion of all the
True voxels in the input is less than deletion_limit.
:param img: np.ndarray
:param deletion_limit: if set, a component is deleted only if its voxel count as a proportion of all the
True voxels in the input is less than deletion_limit.
True voxels in the input is less than deletion_limit.
:param class_index: Optional. Can be used to provide a class index for logging purposes if the image contains
only pixels from a specific class.
"""
@ -281,9 +285,10 @@ def extract_largest_foreground_connected_component(
restrictions: Optional[List[Tuple[int, Optional[float]]]] = None) -> np.ndarray:
"""
Extracts the largest foreground connected component per class from a multi-label array.
:param multi_label_array: An array of class assignments, i.e. value c at (z, y, x) is a class c.
:param restrictions: restrict processing to a subset of the classes (if provided). Each element is a
pair (class_index, threshold) where threshold may be None.
pair (class_index, threshold) where threshold may be None.
:return: An array of class assignments
"""
if restrictions is None:
@ -311,6 +316,7 @@ def merge_masks(masks: np.ndarray) -> np.ndarray:
"""
Merges a one-hot encoded mask tensor (Classes x Z x Y x X) into a multi-label map with labels corresponding to their
index in the original tensor of shape (Z x Y x X).
:param masks: array of shape (Classes x Z x Y x X) containing the mask for each class
:return: merged_mask of shape (Z x Y x X).
"""
@ -346,7 +352,7 @@ def multi_label_array_to_binary(array: np.ndarray, num_classes_including_backgro
:param array: An array of class assignments.
:param num_classes_including_background: The number of class assignments to search for. If 3 classes,
the class assignments to search for will be 0, 1, and 2.
the class assignments to search for will be 0, 1, and 2.
:return: an array of size (num_classes_including_background, array.shape)
"""
return np.stack(list(binaries_from_multi_label_array(array, num_classes_including_background)))
@ -368,7 +374,7 @@ def get_center_crop(image: NumpyOrTorch, crop_shape: TupleInt3) -> NumpyOrTorch:
:param image: The original image to extract crop from
:param crop_shape: The shape of the center crop to extract
:return the center region as specified by the crop_shape argument.
:return: the center region as specified by the crop_shape argument.
"""
if image is None or crop_shape is None:
raise Exception("image and crop_shape must not be None")
@ -399,6 +405,7 @@ def check_array_range(data: np.ndarray, expected_range: Optional[Range] = None,
:param data: The array to check. It can have any size.
:param expected_range: The interval that all array elements must fall into. The first entry is the lower
bound, the second entry is the upper bound.
:param error_prefix: A string to use as the prefix for the error message.
"""
if expected_range is None:
@ -498,9 +505,9 @@ def compute_uncertainty_map_from_posteriors(posteriors: np.ndarray) -> np.ndarra
Normalized Shannon Entropy: https://en.wiktionary.org/wiki/Shannon_entropy
:param posteriors: Normalized probability distribution in range [0, 1] for each class,
in shape: Class x Z x Y x X
in shape: Class x Z x Y x X
:return: Shannon Entropy for each voxel, shape: Z x Y x X expected range is [0,1] where 1 represents
low confidence or uniform posterior distribution across classes.
low confidence or uniform posterior distribution across classes.
"""
check_if_posterior_array(posteriors)
@ -513,10 +520,11 @@ def gaussian_smooth_posteriors(posteriors: np.ndarray, kernel_size_mm: TupleFloa
Performs Gaussian smoothing on posteriors
:param posteriors: Normalized probability distribution in range [0, 1] for each class,
in shape: Class x Z x Y x X
in shape: Class x Z x Y x X
:param kernel_size_mm: The size of the smoothing kernel in mm to be used in each dimension (Z, Y, X)
:param voxel_spacing_mm: Voxel spacing to use to map from mm space to pixel space for the
Gaussian sigma parameter for each dimension in (Z x Y x X) order.
Gaussian sigma parameter for each dimension in (Z x Y x X) order.
:return:
"""
check_if_posterior_array(posteriors)
@ -557,10 +565,11 @@ def segmentation_to_one_hot(segmentation: torch.Tensor,
:param segmentation: A segmentation as a multi-label map of shape [B, C, Z, Y, X]
:param use_gpu: If true, and the input is not yet on the GPU, move the intermediate tensors to the GPU. The result
will be on the same device as the argument `segmentation`
will be on the same device as the argument `segmentation`
:param result_dtype: The torch data type that the result tensor should have. This would be either float16 or float32
:return: A torch tensor with one-hot encoding of the segmentation of shape
[B, C*HDF5_NUM_SEGMENTATION_CLASSES, Z, Y, X]
[B, C*HDF5_NUM_SEGMENTATION_CLASSES, Z, Y, X]
"""
def to_cuda(x: torch.Tensor) -> torch.Tensor:
@ -702,6 +711,7 @@ def apply_summed_probability_rules(model_config: SegmentationModelBase,
:param model_config: Model configuration information
:param posteriors: Confidences per voxel per class, in format Batch x Classes x Z x Y x X if batched,
or Classes x Z x Y x X if not batched.
:param segmentation: Class labels per voxel, in format Batch x Z x Y x X if batched, or Z x Y x X if not batched.
:return: Modified segmentation, as Batch x Z x Y x X if batched, or Z x Y x X if not batched.
"""

Просмотреть файл

@ -207,6 +207,7 @@ def load_nifti_image(path: PathOrString, image_type: Optional[Type] = float) ->
:param path: The path to the image to load.
:return: A numpy array of the image and header data if applicable.
:param image_type: The type to load the image in, set to None to not cast, default is float
:raises ValueError: If the path is invalid or the image is not 3D.
"""
@ -214,6 +215,7 @@ def load_nifti_image(path: PathOrString, image_type: Optional[Type] = float) ->
def _is_valid_image_path(_path: Path) -> bool:
"""
Validates a path for an image. Image must be .nii, or .nii.gz.
:param _path: The path to the file.
:return: True if it is valid, False otherwise
"""
@ -241,6 +243,7 @@ def load_nifti_image(path: PathOrString, image_type: Optional[Type] = float) ->
def load_numpy_image(path: PathOrString, image_type: Optional[Type] = None) -> np.ndarray:
"""
Loads an array from a numpy file (npz or npy). The array is converted to image_type or untouched if None
:param path: The path to the numpy file.
:param image_type: The dtype to cast the array
:return: ndarray
@ -258,6 +261,7 @@ def load_numpy_image(path: PathOrString, image_type: Optional[Type] = None) -> n
def load_dicom_image(path: PathOrString) -> np.ndarray:
"""
Loads an array from a single dicom file.
:param path: The path to the dicom file.
"""
ds = dicom.dcmread(path)
@ -278,6 +282,7 @@ def load_dicom_image(path: PathOrString) -> np.ndarray:
def load_hdf5_dataset_from_file(path_str: Path, dataset_name: str) -> np.ndarray:
"""
Loads a hdf5 dataset from a file as an ndarray
:param path_str: The path to the HDF5 file
:param dataset_name: The dataset name in the HDF5 file that we want to load
:return: ndarray
@ -292,15 +297,17 @@ def load_hdf5_dataset_from_file(path_str: Path, dataset_name: str) -> np.ndarray
def load_hdf5_file(path_str: Union[str, Path], load_segmentation: bool = False) -> HDF5Object:
"""
Loads a single HDF5 file.
:param path_str: The path of the HDF5 file that should be loaded.
:param load_segmentation: If True, the `segmentation` field of the result object will be populated. If
False, the field will be set to None.
False, the field will be set to None.
:return: HDF5Object
"""
def _is_valid_hdf5_path(_path: Path) -> bool:
"""
Validates a path for an image
:param _path:
:return:
"""
@ -331,12 +338,14 @@ def load_images_and_stack(files: Iterable[Path],
:param files: The paths of the files to load.
:param load_segmentation: If True it loads segmentation if present on the same file as the image. This is only
supported for loading from HDF5 files.
supported for loading from HDF5 files.
:param center_crop_size: If supplied, all loaded images will be cropped to the size given here. The crop will be
taken from the center of the image.
taken from the center of the image.
:param image_size: If supplied, all loaded images will be resized immediately after loading.
:return: A wrapper class that contains the loaded images, and if load_segmentation is True, also the segmentations
that were present in the files.
that were present in the files.
"""
images = []
segmentations = []
@ -420,6 +429,7 @@ def load_labels_from_dataset_source(dataset_source: PatientDatasetSource, check_
In the future, this function will be used to load global class and non-imaging information as well.
:type image_size: Image size, tuple of integers.
:param dataset_source: The dataset source for which channels are to be loaded into memory.
:param check_exclusive: Check that the labels are mutually exclusive (defaults to True).
:return: A label sample object containing ground-truth information.
@ -467,6 +477,7 @@ def load_image(path: PathOrString, image_type: Optional[Type] = float) -> ImageW
For segmentation binary |<dataset_name>|<channel index>
For segmentation multimap |<dataset_name>|<channel index>|<multimap value>
The expected dimensions to be (channel, Z, Y, X)
:param path: The path to the file
:param image_type: The type of the image
"""
@ -625,7 +636,7 @@ def store_binary_mask_as_nifti(image: np.ndarray, header: ImageHeader, file_name
:param header: The image header
:param file_name: The name of the file for this image.
:return: the path to the saved image
:raises: when image is not binary
:raises Exception: when image is not binary
"""
if not is_binary_array(image):
raise Exception("Array values must be binary.")

Просмотреть файл

@ -38,14 +38,17 @@ def get_padding_from_kernel_size(padding: PaddingMode,
Returns padding value required for convolution layers based on input kernel size and dilation.
:param padding: Padding type (Enum) {`zero`, `no_padding`}. Option `zero` is intended to preserve the tensor shape.
In `no_padding` option, padding is not applied and the function returns only zeros.
In `no_padding` option, padding is not applied and the function returns only zeros.
:param kernel_size: Spatial support of the convolution kernel. It is used to determine the padding size. This can be
a scalar, tuple or array.
a scalar, tuple or array.
:param dilation: Dilation of convolution kernel. It is used to determine the padding size. This can be a scalar,
tuple or array.
tuple or array.
:param num_dimensions: The number of dimensions that the returned padding tuple should have, if both
kernel_size and dilation are scalars.
:return padding value required for convolution layers based on input kernel size and dilation.
kernel_size and dilation are scalars.
:return: padding value required for convolution layers based on input kernel size and dilation.
"""
if isinstance(kernel_size, Sized):
num_dimensions = len(kernel_size)
@ -70,8 +73,9 @@ def get_upsampling_kernel_size(downsampling_factor: IntOrTuple3, num_dimensions:
https://distill.pub/2016/deconv-checkerboard/
:param downsampling_factor: downsampling factor use for each dimension of the kernel. Can be
either a list of len(num_dimension) with one factor per dimension or an int in which case the
either a list of len(num_dimension) with one factor per dimension or an int in which case the
same factor will be applied for all dimension.
:param num_dimensions: number of dimensions of the kernel
:return: upsampling_kernel_size
"""
@ -98,6 +102,7 @@ def set_model_to_eval_mode(model: torch.nn.Module) -> Generator:
"""
Puts the given torch model into eval mode. At the end of the context, resets the state of the training flag to
what is was before the call.
:param model: The model to modify.
"""
old_mode = model.training

Просмотреть файл

@ -54,6 +54,7 @@ class MetricsPerPatientWriter:
"""
Writes the per-patient per-structure metrics to a CSV file.
Sorting is done first by structure name, then by Dice score ascending.
:param file_name: The name of the file to write to.
"""
df = self.to_data_frame()
@ -70,7 +71,7 @@ class MetricsPerPatientWriter:
:param file_path: The name of the file to write to.
:param allow_incomplete_labels: boolean flag. If false, all ground truth files must be provided.
If true, ground truth files are optional and we add a total_patients count column for easy
If true, ground truth files are optional and we add a total_patients count column for easy
comparison. (Defaults to False.)
"""
@ -130,6 +131,7 @@ class MetricsPerPatientWriter:
def get_number_of_voxels_per_class(labels: torch.Tensor) -> torch.Tensor:
"""
Computes the number of voxels for each class in a one-hot label map.
:param labels: one-hot label map in shape Batches x Classes x Z x Y x X or Classes x Z x Y x X
:return: A tensor of shape [Batches x Classes] containing the number of non-zero voxels along Z, Y, X
"""
@ -208,9 +210,9 @@ def binary_classification_accuracy(model_output: Union[torch.Tensor, np.ndarray]
:param model_output: A tensor containing model outputs.
:param label: A tensor containing class labels.
:param threshold: the cut-off probability threshold for predictions. If model_ouput is > threshold, the predicted
class is 1 else 0.
class is 1 else 0.
:return: 1.0 if all predicted classes match the expected classes given in 'labels'. 0.0 if no predicted classes
match their labels.
match their labels.
"""
model_output, label = convert_input_and_label(model_output, label)
predicted_class = model_output > threshold
@ -254,7 +256,7 @@ def convert_input_and_label(model_output: Union[torch.Tensor, np.ndarray],
label: Union[torch.Tensor, np.ndarray]) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Ensures that both model_output and label are tensors of dtype float32.
:return a Tuple with model_output, label as float tensors.
:return: a Tuple with model_output, label as float tensors.
"""
if not torch.is_tensor(model_output):
model_output = torch.tensor(model_output)
@ -268,8 +270,9 @@ def is_missing_ground_truth(ground_truth: np.ndarray) -> bool:
calculate_metrics_per_class in metrics.py and plot_contours_for_all_classes in plotting.py both
check whether there is ground truth missing using this simple check for NaN value at 0, 0, 0.
To avoid duplicate code we bring it here as a utility function.
:param ground_truth: ground truth binary array with dimensions: [Z x Y x X].
:param label_id: Integer index of the label to check.
:returns: True if the label is missing (signified by NaN), False otherwise.
:return: True if the label is missing (signified by NaN), False otherwise.
"""
return np.isnan(ground_truth[0, 0, 0])

Просмотреть файл

@ -82,8 +82,8 @@ def validate_dataset_paths(
Validates that the required dataset csv file exists in the given path.
:param dataset_path: The base path
:param custom_dataset_csv : The name of the dataset csv file
:raise ValueError if the dataset does not exist.
:param custom_dataset_csv: The name of the dataset csv file
:raises: ValueError if the dataset does not exist.
"""
if not dataset_path.is_dir():
raise ValueError("The dataset_path argument should be the path to the base directory of the data "
@ -107,14 +107,17 @@ def check_size_matches(arg1: Union[np.ndarray, torch.Tensor],
:param arg1: The first array to check.
:param arg2: The second array to check.
:param dim1: The expected number of dimensions of arg1. If zero, no check for number of dimensions will be
conducted.
conducted.
:param dim2: The expected number of dimensions of arg2. If zero, no check for number of dimensions will be
conducted.
conducted.
:param matching_dimensions: The dimensions along which the two arguments have to match. For example, if
arg1.ndim==4 and arg2.ndim==5, matching_dimensions==[3] checks if arg1.shape[3] == arg2.shape[3].
arg1.ndim==4 and arg2.ndim==5, matching_dimensions==[3] checks if arg1.shape[3] == arg2.shape[3].
:param arg1_name: If provided, all error messages will use that string to instead of "arg1"
:param arg2_name: If provided, all error messages will use that string to instead of "arg2"
:raise ValueError if shapes don't match
:raises: ValueError if shapes don't match
"""
if arg1 is None or arg2 is None:
raise Exception("arg1 and arg2 cannot be None.")
@ -125,10 +128,11 @@ def check_size_matches(arg1: Union[np.ndarray, torch.Tensor],
def check_dim(expected: int, actual_shape: Any, name: str) -> None:
"""
Check if actual_shape is equal to the expected shape
:param expected: expected shape
:param actual_shape:
:param name: variable name
:raise ValueError if not the same shape
:raises: ValueError if not the same shape
"""
if len(actual_shape) != expected:
raise ValueError("'{}' was expected to have ndim == {}, but is {}. Shape is {}"
@ -151,6 +155,7 @@ def check_size_matches(arg1: Union[np.ndarray, torch.Tensor],
def set_random_seed(random_seed: int, caller_name: Optional[str] = None) -> None:
"""
Set the seed for the random number generators of python, numpy, torch.random, and torch.cuda for all gpus.
:param random_seed: random seed value to set.
:param caller_name: name of the caller for logging purposes.
"""
@ -170,8 +175,9 @@ def is_test_from_execution_mode(execution_mode: ModelExecutionMode) -> bool:
"""
Returns a boolean by checking the execution type. The output is used to determine the properties
of the forward pass, e.g. model gradient updates or metric computation.
:return True if execution mode is VAL or TEST, False if TRAIN
:raise ValueError if the execution mode is invalid
:return: True if execution mode is VAL or TEST, False if TRAIN
:raises ValueError: if the execution mode is invalid
"""
if execution_mode == ModelExecutionMode.TRAIN:
return False
@ -207,6 +213,6 @@ def is_tensor_nan(tensor: torch.Tensor) -> bool:
:param tensor: The tensor to check.
:return: True if any of the tensor elements is Not a Number, False if all entries are valid numbers.
If the tensor is empty, the function returns False.
If the tensor is empty, the function returns False.
"""
return bool(torch.isnan(tensor).any().item())

Просмотреть файл

@ -11,6 +11,7 @@ from InnerEye.Common.type_annotations import TupleInt3
def random_colour(rng: random.Random) -> TupleInt3:
"""
Generates a random colour in RGB given a random number generator
:param rng: Random number generator
:return: Tuple with random colour in RGB
"""
@ -23,6 +24,7 @@ def random_colour(rng: random.Random) -> TupleInt3:
def generate_random_colours_list(rng: random.Random, size: int) -> List[TupleInt3]:
"""
Generates a list of random colours in RGB given a random number generator and the size of this list
:param rng: random number generator
:param size: size of the list
:return: list of random colours in RGB

Просмотреть файл

@ -203,6 +203,7 @@ def generate_and_print_model_summary(config: ModelConfigBase, model: DeviceAware
"""
Writes a human readable summary of the present model to logging.info, and logs the number of trainable
parameters to AzureML.
:param config: The configuration for the model.
:param model: The instantiated Pytorch model.
"""
@ -264,6 +265,7 @@ class ScalarModelInputsAndLabels():
def move_to_device(self, device: Union[str, torch.device]) -> None:
"""
Moves the model_inputs and labels field of the present object to the given device. This is done in-place.
:param device: The target device.
"""
self.model_inputs = [t.to(device=device) for t in self.model_inputs]
@ -274,13 +276,15 @@ def get_scalar_model_inputs_and_labels(model: torch.nn.Module,
sample: Dict[str, Any]) -> ScalarModelInputsAndLabels:
"""
For a model that predicts scalars, gets the model input tensors from a sample returned by the data loader.
:param model: The instantiated PyTorch model.
:param target_indices: If this list is non-empty, assume that the model is a sequence model, and build the
model inputs and labels for a model that predicts those specific positions in the sequence. If the list is empty,
model inputs and labels for a model that predicts those specific positions in the sequence. If the list is empty,
assume that the model is a normal (non-sequence) model.
:param sample: A training sample, as returned by a PyTorch data loader (dictionary mapping from field name to value)
:return: An instance of ScalarModelInputsAndLabels, containing the list of model input tensors,
label tensor, subject IDs, and the data item reconstructed from the data loader output
label tensor, subject IDs, and the data item reconstructed from the data loader output
"""
scalar_model: DeviceAwareModule[ScalarItem, torch.Tensor] = model # type: ignore
scalar_item = ScalarItem.from_dict(sample)

Просмотреть файл

@ -14,6 +14,7 @@ def get_view_dim_and_origin(plane: Plane) -> Tuple[int, str]:
"""
Get the axis along which to slice, as well as the orientation of the origin, to ensure images
are plotted as expected
:param plane: the plane in which to plot (i.e. axial, sagittal or coronal)
:return:
"""
@ -33,6 +34,7 @@ def get_cropped_axes(image: np.ndarray, boundary_width: int = 5) -> Tuple[slice,
"""
Return the min and max values on both x and y axes where the image is not empty
Method: find the min and max of all non-zero pixels in the image, and add a border
:param image: the image to be cropped
:param boundary_width: number of pixels boundary to add around bounding box
:return:

Просмотреть файл

@ -50,10 +50,11 @@ def sequences_to_padded_tensor(sequences: List[torch.Tensor],
padding_value: float = 0.0) -> torch.Tensor:
"""
Method to convert possibly unequal length sequences to a padded tensor.
:param sequences: List of Tensors to pad
:param padding_value: Padding value to use, default is 0.0
:return: Output tensor with shape B x * where * is the max dimensions from the list of provided tensors.
And B is the number of tensors in the list of sequences provided.
And B is the number of tensors in the list of sequences provided.
"""
return pad_sequence(sequences, batch_first=True, padding_value=padding_value)
@ -80,6 +81,7 @@ def get_masked_model_outputs_and_labels(model_output: torch.Tensor,
Helper function to get masked model outputs, labels and their associated subject ids. Masking is performed
by excluding the NaN model outputs and labels based on a bool mask created using the
occurrences of NaN in the labels provided.
:param model_output: The model output tensor to mask.
:param labels: The label tensor to use for mask, and use for masking.
:param subject_ids: The associated subject ids.
@ -125,6 +127,7 @@ def apply_sequence_model_loss(loss_fn: torch.nn.Module,
"""
Applies a loss function to a model output and labels, when the labels come from sequences with unequal length.
Missing sequence elements are masked out.
:param loss_fn: The loss function to apply to the sequence elements that are present.
:param model_output: The model outputs
:param labels: The ground truth labels.

Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше