In this PR:
- Fix is introduced to handle missing cross-validation rounds (e.g., when a round fails)
- Extra validation epoch is handled
- Downloading the AML metrics json and saving it as a dataframe is separated into two functions (current version gives a `KeyError` when the json is downloaded for the first time)

Co-authored-by: Harshita Sharma <t-hsharma@microsoft.com>
This commit is contained in:
Harshita Sharma 2022-11-07 09:39:09 +00:00 коммит произвёл GitHub
Родитель 41d30807aa
Коммит 6ba3cfe685
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
5 изменённых файлов: 171 добавлений и 100 удалений

Просмотреть файл

@ -18,8 +18,9 @@ from health_cpath.utils.output_utils import (AML_LEGACY_TEST_OUTPUTS_CSV, AML_TE
AML_VAL_OUTPUTS_CSV) AML_VAL_OUTPUTS_CSV)
from health_cpath.utils.report_utils import (collect_hyperdrive_metrics, collect_hyperdrive_outputs, from health_cpath.utils.report_utils import (collect_hyperdrive_metrics, collect_hyperdrive_outputs,
child_runs_have_val_and_test_outputs, get_best_epoch_metrics, child_runs_have_val_and_test_outputs, get_best_epoch_metrics,
get_best_epochs, get_hyperdrive_metrics_table, get_formatted_run_info, get_best_epochs, get_child_runs_hyperparams, get_hyperdrive_metrics_table,
collect_class_info) get_formatted_run_info, collect_class_info, get_max_epochs,
download_hyperdrive_metrics_if_required)
from health_cpath.utils.naming import MetricsKey, ModelKey from health_cpath.utils.naming import MetricsKey, ModelKey
@ -53,10 +54,16 @@ def generate_html_report(parent_run_id: str, output_dir: Path,
report.add_heading("Azure ML metrics", level=2) report.add_heading("Azure ML metrics", level=2)
# Download metrics from AML. Can take several seconds for each child run # Download metrics from AML. Can take several seconds for each child run
metrics_df = collect_hyperdrive_metrics( metrics_json = download_hyperdrive_metrics_if_required(parent_run_id, report_dir, aml_workspace,
parent_run_id, report_dir, aml_workspace, overwrite=overwrite, hyperdrive_arg_name=hyperdrive_arg_name overwrite=overwrite, hyperdrive_arg_name=hyperdrive_arg_name)
)
best_epochs = get_best_epochs(metrics_df, f'{ModelKey.VAL}/{primary_metric}', maximise=True) # Get metrics dataframe from the downloaded json file
metrics_df = collect_hyperdrive_metrics(metrics_json=metrics_json)
hyperparameters_children = get_child_runs_hyperparams(metrics_df)
max_epochs_dict = get_max_epochs(hyperparams_children=hyperparameters_children)
best_epochs = get_best_epochs(metrics_df=metrics_df, primary_metric=f'{ModelKey.VAL}/{primary_metric}',
max_epochs_dict=max_epochs_dict, maximise=True)
# Add training curves for loss and AUROC (train and val.) # Add training curves for loss and AUROC (train and val.)
render_training_curves(report, heading="Training curves", level=3, render_training_curves(report, heading="Training curves", level=3,
@ -64,7 +71,7 @@ def generate_html_report(parent_run_id: str, output_dir: Path,
primary_metric=primary_metric) primary_metric=primary_metric)
# Get metrics list with class names # Get metrics list with class names
num_classes, class_names = collect_class_info(metrics_df=metrics_df) num_classes, class_names = collect_class_info(hyperparams_children=hyperparameters_children)
base_metrics_list: List[str] = [MetricsKey.ACC, MetricsKey.AUROC, MetricsKey.AVERAGE_PRECISION, base_metrics_list: List[str] = [MetricsKey.ACC, MetricsKey.AUROC, MetricsKey.AVERAGE_PRECISION,
MetricsKey.COHENKAPPA] MetricsKey.COHENKAPPA]
@ -87,56 +94,58 @@ def generate_html_report(parent_run_id: str, output_dir: Path,
metrics_df=metrics_df, best_epochs=None, metrics_df=metrics_df, best_epochs=None,
base_metrics_list=base_metrics_list, metrics_prefix=f'{ModelKey.TEST}/') base_metrics_list=base_metrics_list, metrics_prefix=f'{ModelKey.TEST}/')
has_val_and_test_outputs = child_runs_have_val_and_test_outputs(parent_run) # Get output data frames if available
try:
# Get output data frames has_val_and_test_outputs = child_runs_have_val_and_test_outputs(parent_run)
if has_val_and_test_outputs:
output_filename_val = AML_VAL_OUTPUTS_CSV
outputs_dfs_val = collect_hyperdrive_outputs(parent_run_id=parent_run_id, download_dir=report_dir,
aml_workspace=aml_workspace,
output_filename=output_filename_val, overwrite=overwrite,
hyperdrive_arg_name=hyperdrive_arg_name)
if include_test:
output_filename_test = AML_TEST_OUTPUTS_CSV if has_val_and_test_outputs else AML_LEGACY_TEST_OUTPUTS_CSV
outputs_dfs_test = collect_hyperdrive_outputs(parent_run_id=parent_run_id, download_dir=report_dir,
aml_workspace=aml_workspace,
output_filename=output_filename_test, overwrite=overwrite,
hyperdrive_arg_name=hyperdrive_arg_name)
if num_classes == 1:
# Currently ROC and PR curves rendered only for binary case
# TODO: Enable rendering of multi-class ROC and PR curves
report.add_heading("ROC and PR curves", level=2)
if has_val_and_test_outputs: if has_val_and_test_outputs:
# Add val. ROC and PR curves output_filename_val = AML_VAL_OUTPUTS_CSV
render_roc_and_pr_curves(report=report, heading="Validation ROC and PR curves", level=3, outputs_dfs_val = collect_hyperdrive_outputs(parent_run_id=parent_run_id, download_dir=report_dir,
report_dir=report_dir, aml_workspace=aml_workspace,
outputs_dfs=outputs_dfs_val, output_filename=output_filename_val, overwrite=overwrite,
prefix=f'{ModelKey.VAL}_') hyperdrive_arg_name=hyperdrive_arg_name)
if include_test:
output_filename_test = AML_TEST_OUTPUTS_CSV if has_val_and_test_outputs else AML_LEGACY_TEST_OUTPUTS_CSV
outputs_dfs_test = collect_hyperdrive_outputs(parent_run_id=parent_run_id, download_dir=report_dir,
aml_workspace=aml_workspace,
output_filename=output_filename_test, overwrite=overwrite,
hyperdrive_arg_name=hyperdrive_arg_name)
if num_classes == 1:
# Currently ROC and PR curves rendered only for binary case
# TODO: Enable rendering of multi-class ROC and PR curves
report.add_heading("ROC and PR curves", level=2)
if has_val_and_test_outputs:
# Add val. ROC and PR curves
render_roc_and_pr_curves(report=report, heading="Validation ROC and PR curves", level=3,
report_dir=report_dir,
outputs_dfs=outputs_dfs_val,
prefix=f'{ModelKey.VAL}_')
if include_test:
# Add test ROC and PR curves
render_roc_and_pr_curves(report=report, heading="Test ROC and PR curves", level=3,
report_dir=report_dir,
outputs_dfs=outputs_dfs_test,
prefix=f'{ModelKey.TEST}_')
# Add confusion matrices for each fold
report.add_heading("Confusion matrices", level=2)
if has_val_and_test_outputs:
# Add val. confusion matrices
render_confusion_matrices(report=report, heading="Validation confusion matrices", level=3,
class_names=class_names,
report_dir=report_dir, outputs_dfs=outputs_dfs_val,
prefix=f'{ModelKey.VAL}_')
if include_test: if include_test:
# Add test ROC and PR curves # Add test confusion matrices
render_roc_and_pr_curves(report=report, heading="Test ROC and PR curves", level=3, render_confusion_matrices(report=report, heading="Test confusion matrices", level=3,
report_dir=report_dir, class_names=class_names,
outputs_dfs=outputs_dfs_test, report_dir=report_dir, outputs_dfs=outputs_dfs_test,
prefix=f'{ModelKey.TEST}_') prefix=f'{ModelKey.TEST}_')
# Add confusion matrices for each fold except ValueError as e:
print(e)
report.add_heading("Confusion matrices", level=2) print("Since all expected output files were not found, skipping ROC-PR curves and confusion matrices.")
if has_val_and_test_outputs:
# Add val. confusion matrices
render_confusion_matrices(report=report, heading="Validation confusion matrices", level=3,
class_names=class_names,
report_dir=report_dir, outputs_dfs=outputs_dfs_val,
prefix=f'{ModelKey.VAL}_')
if include_test:
# Add test confusion matrices
render_confusion_matrices(report=report, heading="Test confusion matrices", level=3,
class_names=class_names,
report_dir=report_dir, outputs_dfs=outputs_dfs_test,
prefix=f'{ModelKey.TEST}_')
# TODO: Add qualitative model outputs # TODO: Add qualitative model outputs
# report.add_heading("Qualitative model outputs", level=2) # report.add_heading("Qualitative model outputs", level=2)

Просмотреть файл

@ -229,14 +229,17 @@ def plot_hyperdrive_training_curves(metrics_df: pd.DataFrame, train_metric: str,
for k in sorted(metrics_df.columns): for k in sorted(metrics_df.columns):
train_values = metrics_df.loc[train_metric, k] train_values = metrics_df.loc[train_metric, k]
val_values = metrics_df.loc[val_metric, k] val_values = metrics_df.loc[val_metric, k]
line, = ax.plot(train_values, **TRAIN_STYLE, label=f"Child {k}") if train_values is not None:
color = line.get_color() line, = ax.plot(train_values, **TRAIN_STYLE, label=f"Child {k}")
ax.plot(val_values, color=color, **VAL_STYLE) color = line.get_color()
if val_values is not None:
ax.plot(val_values, color=color, **VAL_STYLE)
if best_epochs is not None: if best_epochs is not None:
best_epoch = best_epochs[k] best_epoch = best_epochs[k]
ax.plot(best_epoch, train_values[best_epoch], color=color, zorder=1000, **BEST_TRAIN_MARKER_STYLE) if best_epoch is not None:
ax.plot(best_epoch, val_values[best_epoch], color=color, zorder=1000, **BEST_VAL_MARKER_STYLE) ax.plot(best_epoch, train_values[best_epoch], color=color, zorder=1000, **BEST_TRAIN_MARKER_STYLE)
ax.axvline(best_epoch, color=color, **BEST_EPOCH_LINE_STYLE) ax.plot(best_epoch, val_values[best_epoch], color=color, zorder=1000, **BEST_VAL_MARKER_STYLE)
ax.axvline(best_epoch, color=color, **BEST_EPOCH_LINE_STYLE)
ax.grid(color='0.9') ax.grid(color='0.9')
ax.set_xlabel("Epoch") ax.set_xlabel("Epoch")
if ylabel: if ylabel:

Просмотреть файл

@ -89,6 +89,7 @@ class AMLMetricsJsonKey(str, Enum):
VALUE = 'value' VALUE = 'value'
N_CLASSES = 'n_classes' N_CLASSES = 'n_classes'
CLASS_NAMES = 'class_names' CLASS_NAMES = 'class_names'
MAX_EPOCHS = 'max_epochs'
class PlotOption(Enum): class PlotOption(Enum):

Просмотреть файл

@ -4,7 +4,7 @@
# ------------------------------------------------------------------------------------------- # -------------------------------------------------------------------------------------------
from pathlib import Path from pathlib import Path
from typing import Dict, List, Sequence, Tuple from typing import Any, Dict, List, Sequence, Tuple
import dateutil.parser import dateutil.parser
import numpy as np import numpy as np
@ -92,10 +92,10 @@ def collect_hyperdrive_outputs(parent_run_id: str, download_dir: Path, aml_works
return dict(sorted(all_outputs_dfs.items())) # type: ignore return dict(sorted(all_outputs_dfs.items())) # type: ignore
def collect_hyperdrive_metrics(parent_run_id: str, download_dir: Path, aml_workspace: Workspace, def download_hyperdrive_metrics_if_required(parent_run_id: str, download_dir: Path, aml_workspace: Workspace,
hyperdrive_arg_name: str = "crossval_index", hyperdrive_arg_name: str = "crossval_index",
overwrite: bool = False) -> pd.DataFrame: overwrite: bool = False) -> Path:
"""Fetch metrics logged to Azure ML from hyperdrive runs as a dataframe. """Fetch metrics logged to Azure ML from hyperdrive runs.
Will only download the metrics if they do not already exist locally, as this can take several Will only download the metrics if they do not already exist locally, as this can take several
seconds for each child run. seconds for each child run.
@ -105,12 +105,11 @@ def collect_hyperdrive_metrics(parent_run_id: str, download_dir: Path, aml_works
:param aml_workspace: Azure ML workspace in which the runs were executed. :param aml_workspace: Azure ML workspace in which the runs were executed.
:param hyperdrive_arg_name: Name of the Hyperdrive argument used for indexing the child runs. :param hyperdrive_arg_name: Name of the Hyperdrive argument used for indexing the child runs.
:param overwrite: Whether to force the download even if metrics are already saved locally. :param overwrite: Whether to force the download even if metrics are already saved locally.
:return: A dataframe in the format returned by :py:func:`~health_azure.aggregate_hyperdrive_metrics()`. :return: The path of the downloaded json file.
""" """
metrics_json = download_dir / "aml_metrics.json" metrics_json = download_dir / "aml_metrics.json"
if not overwrite and metrics_json.is_file(): if not overwrite and metrics_json.is_file():
print(f"AML metrics file already exists at {metrics_json}") print(f"AML metrics file already exists at {metrics_json}")
metrics_df = pd.read_json(metrics_json)
else: else:
metrics_df = aggregate_hyperdrive_metrics(run_id=parent_run_id, metrics_df = aggregate_hyperdrive_metrics(run_id=parent_run_id,
child_run_arg_name=hyperdrive_arg_name, child_run_arg_name=hyperdrive_arg_name,
@ -118,7 +117,17 @@ def collect_hyperdrive_metrics(parent_run_id: str, download_dir: Path, aml_works
metrics_json.parent.mkdir(parents=True, exist_ok=True) metrics_json.parent.mkdir(parents=True, exist_ok=True)
print(f"Writing AML metrics file to {metrics_json}") print(f"Writing AML metrics file to {metrics_json}")
df_to_json(metrics_df, metrics_json) df_to_json(metrics_df, metrics_json)
return metrics_df.sort_index(axis='columns') return metrics_json
def collect_hyperdrive_metrics(metrics_json: Path) -> pd.DataFrame:
"""
Collect the hyperdrive metrics from the downloaded metrics json file in a dataframe.
:param metrics_json: Path of the downloaded metrics file `aml_metrics.json`.
:return: A dataframe in the format returned by :py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
"""
metrics_df = pd.read_json(metrics_json).sort_index(axis='columns')
return metrics_df
def get_hyperdrive_metrics_table(metrics_df: pd.DataFrame, metrics_list: Sequence[str]) -> pd.DataFrame: def get_hyperdrive_metrics_table(metrics_df: pd.DataFrame, metrics_list: Sequence[str]) -> pd.DataFrame:
@ -139,13 +148,16 @@ def get_hyperdrive_metrics_table(metrics_df: pd.DataFrame, metrics_list: Sequenc
values: pd.Series = metrics_df.loc[metric] values: pd.Series = metrics_df.loc[metric]
mean = values.mean() mean = values.mean()
std = values.std() std = values.std()
row = [metric] + [f"{v:.3f}" for v in values] + [f"{mean:.3f} ± {std:.3f}"] round_values: List[str] = [f"{v:.3f}" if v is not None else str(np.nan) for v in values]
agg_values: List[str] = [f"{mean:.3f} ± {std:.3f}"]
row = [metric] + round_values + agg_values
metrics_rows.append(row) metrics_rows.append(row)
table = pd.DataFrame(metrics_rows, columns=header).set_index(header[0]) table = pd.DataFrame(metrics_rows, columns=header).set_index(header[0])
return table return table
def get_best_epochs(metrics_df: pd.DataFrame, primary_metric: str, maximise: bool = True) -> Dict[int, int]: def get_best_epochs(metrics_df: pd.DataFrame, primary_metric: str, max_epochs_dict: Dict[int, int],
maximise: bool = True) -> Dict[int, Any]:
"""Determine the best epoch for each hyperdrive child run based on a given metric. """Determine the best epoch for each hyperdrive child run based on a given metric.
The returned epoch indices are relative to the logging frequency of the chosen metric, i.e. The returned epoch indices are relative to the logging frequency of the chosen metric, i.e.
@ -154,16 +166,26 @@ def get_best_epochs(metrics_df: pd.DataFrame, primary_metric: str, maximise: boo
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and :param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
:py:func:`~health_azure.aggregate_hyperdrive_metrics()`. :py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
:param primary_metric: Name of the reference metric to optimise. :param primary_metric: Name of the reference metric to optimise.
:max_epochs_dict: A dictionary of the maximum number of epochs in each cross-validation round.
:param maximise: Whether the given metric should be maximised (minimised if `False`). :param maximise: Whether the given metric should be maximised (minimised if `False`).
:return: Dictionary mapping each hyperdrive child index to its best epoch. :return: Dictionary mapping each hyperdrive child index to its best epoch.
""" """
best_fn = np.argmax if maximise else np.argmin best_epochs: Dict[int, Any] = {}
best_epochs = metrics_df.loc[primary_metric].apply(best_fn) for child_index in metrics_df.columns:
return best_epochs.to_dict() primary_metric_list = metrics_df[child_index][primary_metric]
if primary_metric_list is not None:
# If extra validation epoch was logged (N+1), return only the first N elements
primary_metric_list = primary_metric_list[:-1] \
if (len(primary_metric_list) == max_epochs_dict[child_index] + 1) else primary_metric_list
best_epochs[child_index] = int(np.argmax(primary_metric_list)
if maximise else np.argmin(primary_metric_list))
else:
best_epochs[child_index] = None
return best_epochs
def get_best_epoch_metrics(metrics_df: pd.DataFrame, metrics_list: Sequence[str], def get_best_epoch_metrics(metrics_df: pd.DataFrame, metrics_list: Sequence[str],
best_epochs: Dict[int, int]) -> pd.DataFrame: best_epochs: Dict[int, Any]) -> pd.DataFrame:
"""Extract the values of the selected hyperdrive metrics at the given best epochs. """Extract the values of the selected hyperdrive metrics at the given best epochs.
The `best_epoch` indices are relative to the logging frequency of the chosen primary metric, The `best_epoch` indices are relative to the logging frequency of the chosen primary metric,
@ -179,7 +201,7 @@ def get_best_epoch_metrics(metrics_df: pd.DataFrame, metrics_list: Sequence[str]
containing only scalar values. containing only scalar values.
""" """
best_metrics = [metrics_df.loc[metrics_list, k].apply(lambda values: values[epoch]) best_metrics = [metrics_df.loc[metrics_list, k].apply(lambda values: values[epoch])
for k, epoch in best_epochs.items()] if epoch is not None else metrics_df.loc[metrics_list, k] for k, epoch in best_epochs.items()]
best_metrics_df = pd.DataFrame(best_metrics).T best_metrics_df = pd.DataFrame(best_metrics).T
return best_metrics_df return best_metrics_df
@ -216,20 +238,30 @@ def get_formatted_run_info(parent_run: Run) -> str:
return html return html
def collect_class_info(metrics_df: pd.DataFrame) -> Tuple[int, List[str]]: def get_child_runs_hyperparams(metrics_df: pd.DataFrame) -> Dict[int, Dict]:
""" """
Get the class names from metrics dataframe Get the hyperparameters of each child run from the metrics dataframe.
:param metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and :param: metrics_df: Metrics dataframe, as returned by :py:func:`collect_hyperdrive_metrics()` and
:py:func:`~health_azure.aggregate_hyperdrive_metrics()`. :py:func:`~health_azure.aggregate_hyperdrive_metrics()`.
:return: Number of classes and list of class names :return: A dictionary of hyperparameter dictionaries for the child runs.
""" """
hyperparams = metrics_df[0][AMLMetricsJsonKey.HYPERPARAMS] hyperparams_children = {}
hyperparams_name = hyperparams[AMLMetricsJsonKey.NAME] for child_index in metrics_df.columns:
hyperparams_value = hyperparams[AMLMetricsJsonKey.VALUE] hyperparams = metrics_df[child_index][AMLMetricsJsonKey.HYPERPARAMS]
num_classes_index = hyperparams_name.index(AMLMetricsJsonKey.N_CLASSES) hyperparams_dict = dict(zip(hyperparams[AMLMetricsJsonKey.NAME], hyperparams[AMLMetricsJsonKey.VALUE]))
num_classes = int(hyperparams_value[num_classes_index]) hyperparams_children[child_index] = hyperparams_dict
class_names_index = hyperparams_name.index(AMLMetricsJsonKey.CLASS_NAMES) return hyperparams_children
class_names = hyperparams_value[class_names_index]
def collect_class_info(hyperparams_children: Dict[int, Dict]) -> Tuple[int, List[str]]:
"""
Get the class names from the hyperparameters of child runs.
:param hyperparams_children: Dict of hyperparameter dicts, as returned by :py:func:`get_child_runs_hyperparams()`.
:return: Number of classes and list of class names.
"""
hyperparams_single_run = list(hyperparams_children.values())[0]
num_classes = int(hyperparams_single_run[AMLMetricsJsonKey.N_CLASSES])
class_names = hyperparams_single_run[AMLMetricsJsonKey.CLASS_NAMES]
if class_names == "None": if class_names == "None":
class_names = None class_names = None
else: else:
@ -237,3 +269,15 @@ def collect_class_info(metrics_df: pd.DataFrame) -> Tuple[int, List[str]]:
class_names = [name.lstrip() for name in class_names[1:-1].replace("'", "").split(',')] class_names = [name.lstrip() for name in class_names[1:-1].replace("'", "").split(',')]
class_names = validate_class_names(class_names=class_names, n_classes=num_classes) class_names = validate_class_names(class_names=class_names, n_classes=num_classes)
return (num_classes, list(class_names)) return (num_classes, list(class_names))
def get_max_epochs(hyperparams_children: Dict[int, Dict]) -> Dict[int, int]:
"""
Get the maximum number of epochs for each round from the metrics dataframe.
:param hyperparams_children: Dict of hyperparameter dicts, as returned by :py:func:`get_child_runs_hyperparams()`.
:return: Dictionary with the number of epochs in each hyperdrive run.
"""
max_epochs_dict = {}
for child_index in hyperparams_children.keys():
max_epochs_dict[child_index] = int(hyperparams_children[child_index][AMLMetricsJsonKey.MAX_EPOCHS])
return max_epochs_dict

Просмотреть файл

@ -1,6 +1,6 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import Dict, List, Sequence, Union from typing import Any, Dict, List, Sequence, Union
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import numpy as np import numpy as np
@ -14,7 +14,7 @@ from health_cpath.utils.output_utils import (AML_LEGACY_TEST_OUTPUTS_CSV, AML_OU
from health_cpath.utils.report_utils import (collect_hyperdrive_metrics, collect_hyperdrive_outputs, from health_cpath.utils.report_utils import (collect_hyperdrive_metrics, collect_hyperdrive_outputs,
child_runs_have_val_and_test_outputs, get_best_epoch_metrics, child_runs_have_val_and_test_outputs, get_best_epoch_metrics,
get_best_epochs, get_hyperdrive_metrics_table, get_best_epochs, get_hyperdrive_metrics_table,
run_has_val_and_test_outputs) run_has_val_and_test_outputs, download_hyperdrive_metrics_if_required)
def test_run_has_val_and_test_outputs() -> None: def test_run_has_val_and_test_outputs() -> None:
@ -176,13 +176,24 @@ def metrics_df() -> pd.DataFrame:
'val/auroc': [0.8, 0.9, 0.7], 'val/auroc': [0.8, 0.9, 0.7],
'test/accuracy': 0.9, 'test/accuracy': 0.9,
'test/auroc': 0.9 'test/auroc': 0.9
} },
4: {'val/accuracy': None,
'val/auroc': None,
'test/accuracy': None,
'test/auroc': None
}
}) })
@pytest.fixture @pytest.fixture
def best_epochs(metrics_df: pd.DataFrame) -> Dict[int, int]: def max_epochs_dict() -> Dict[int, int]:
return get_best_epochs(metrics_df, 'val/accuracy', maximise=True) return {0: 3, 1: 10, 3: 3, 4: 3}
@pytest.fixture
def best_epochs(metrics_df: pd.DataFrame, max_epochs_dict: Dict[int, int]) -> Dict[int, Any]:
return get_best_epochs(metrics_df=metrics_df, primary_metric='val/accuracy',
max_epochs_dict=max_epochs_dict, maximise=True)
@pytest.fixture @pytest.fixture
@ -195,13 +206,15 @@ def best_epoch_metrics(metrics_df: pd.DataFrame, best_epochs: Dict[int, int]) ->
def test_collect_hyperdrive_metrics(metrics_df: pd.DataFrame, tmp_path: Path, overwrite: bool) -> None: def test_collect_hyperdrive_metrics(metrics_df: pd.DataFrame, tmp_path: Path, overwrite: bool) -> None:
with patch('health_cpath.utils.report_utils.aggregate_hyperdrive_metrics', with patch('health_cpath.utils.report_utils.aggregate_hyperdrive_metrics',
return_value=metrics_df) as mock_aggregate: return_value=metrics_df) as mock_aggregate:
returned_df = collect_hyperdrive_metrics(parent_run_id="", download_dir=tmp_path, returned_json = download_hyperdrive_metrics_if_required(parent_run_id="", download_dir=tmp_path,
aml_workspace=None, overwrite=overwrite) aml_workspace=None, overwrite=overwrite)
returned_df = collect_hyperdrive_metrics(metrics_json=returned_json)
mock_aggregate.assert_called_once() mock_aggregate.assert_called_once()
mock_aggregate.reset_mock() mock_aggregate.reset_mock()
new_returned_df = collect_hyperdrive_metrics(parent_run_id="", download_dir=tmp_path, new_returned_json = download_hyperdrive_metrics_if_required(parent_run_id="", download_dir=tmp_path,
aml_workspace=None, overwrite=overwrite) aml_workspace=None, overwrite=overwrite)
new_returned_df = collect_hyperdrive_metrics(metrics_json=new_returned_json)
if overwrite: if overwrite:
mock_aggregate.assert_called_once() mock_aggregate.assert_called_once()
else: else:
@ -211,12 +224,13 @@ def test_collect_hyperdrive_metrics(metrics_df: pd.DataFrame, tmp_path: Path, ov
@pytest.mark.parametrize('maximise', [True, False]) @pytest.mark.parametrize('maximise', [True, False])
def test_get_best_epochs(metrics_df: pd.DataFrame, maximise: bool) -> None: def test_get_best_epochs(metrics_df: pd.DataFrame, max_epochs_dict: Dict[int, int], maximise: bool) -> None:
best_epochs = get_best_epochs(metrics_df, 'val/accuracy', maximise=maximise) best_epochs = get_best_epochs(metrics_df=metrics_df, primary_metric='val/accuracy',
max_epochs_dict=max_epochs_dict, maximise=maximise)
assert list(best_epochs.keys()) == list(metrics_df.columns) assert list(best_epochs.keys()) == list(metrics_df.columns)
assert all(isinstance(epoch, int) for epoch in best_epochs.values()) assert all(isinstance(epoch, (int, type(None))) for epoch in best_epochs.values())
expected_best = {0: 0, 1: 1, 3: 2} if maximise else {0: 1, 1: 2, 3: 0} expected_best = {0: 0, 1: 1, 3: 2, 4: None} if maximise else {0: 1, 1: 2, 3: 0, 4: None}
for split in metrics_df.columns: for split in metrics_df.columns:
assert best_epochs[split] == expected_best[split] assert best_epochs[split] == expected_best[split]
@ -244,4 +258,4 @@ def test_get_hyperdrive_metrics_table(
original_values = df.loc[metrics_list].values original_values = df.loc[metrics_list].values
table_values = metrics_table.iloc[:, :-1].applymap(float).values table_values = metrics_table.iloc[:, :-1].applymap(float).values
assert (table_values == original_values).all() assert (table_values[~pd.isnull(table_values)] == original_values[~pd.isnull(original_values)]).all()