This PR contains changes to the class CheckpointDownloader to fix a bug in the local path of the downloaded checkpoint
This commit is contained in:
vale-salvatelli 2021-12-21 14:24:14 +00:00 коммит произвёл GitHub
Родитель 484a0a8b28
Коммит 7902a6e138
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 25 добавлений и 10 удалений

Просмотреть файл

@ -43,6 +43,7 @@ the section headers (Added/Changed/...) and incrementing the package version.
- ([#161](https://github.com/microsoft/hi-ml/pull/161)) Empty string as target folder for a dataset creates an invalid mounting path for the dataset in AzureML (fixes #160)
- ([#167](https://github.com/microsoft/hi-ml/pull/167)) Fix bugs in logging hyperparameters: logging as name/value
table, rather than one column per hyperparameter. Use string logging for all hyperparameters
- ([#174](https://github.com/microsoft/hi-ml/pull/174)) Fix bugs in returned local_checkpoint_path when downloading checkpoints from AML run
### Removed

Просмотреть файл

@ -467,7 +467,7 @@ class RunIdOrListParam(CustomTypeParam):
class CheckpointDownloader:
def __init__(self, run_id: str, checkpoint_filename: str, azure_config_json_path: Path = None,
aml_workspace: Workspace = None, download_dir: PathOrString = "checkpoints",
remote_checkpoint_folder: PathOrString = "checkpoints") -> None:
remote_checkpoint_dir: PathOrString = "checkpoints") -> None:
"""
Utility class for downloading checkpoint files from an Azure ML run
@ -480,25 +480,29 @@ class CheckpointDownloader:
:param aml_workspace: An optional Azure ML Workspace object. If not running inside an AML Run, and no
azure_config_json_path is provided, this is required.
:param download_dir: The local directory in which to save the downloaded checkpoint files.
:param remote_checkpoint_folder: The remote folder from which to download the checkpoint file
:param remote_checkpoint_dir: The remote folder from which to download the checkpoint file
"""
self.azure_config_json_path = azure_config_json_path
self.aml_workspace = aml_workspace
self.run_id = run_id
self.checkpoint_filename = checkpoint_filename
self.download_dir = Path(download_dir)
self.remote_checkpoint_folder = Path(remote_checkpoint_folder)
self.remote_checkpoint_dir = Path(remote_checkpoint_dir)
@property
def local_checkpoint_path(self) -> Path:
def local_checkpoint_dir(self) -> Path:
# in case we run_id is a run recovery id, extract the run id
run_id_parts = self.run_id.split(":")
run_id = run_id_parts[-1]
return self.download_dir / run_id / self.checkpoint_filename
return self.download_dir / run_id
@property
def remote_checkpoint_path(self) -> Path:
return self.remote_checkpoint_folder / self.checkpoint_filename
return self.remote_checkpoint_dir / self.checkpoint_filename
@property
def local_checkpoint_path(self) -> Path:
return self.local_checkpoint_dir / self.remote_checkpoint_path
def download_checkpoint_if_necessary(self) -> Path:
"""Downloads the specified checkpoint if it does not already exist.
@ -509,9 +513,9 @@ class CheckpointDownloader:
workspace_config_path=self.azure_config_json_path)
if not self.local_checkpoint_path.exists():
local_checkpoint_dir = self.local_checkpoint_path.parent
local_checkpoint_dir.mkdir(exist_ok=True, parents=True)
download_checkpoints_from_run_id(self.run_id, str(self.remote_checkpoint_path), local_checkpoint_dir,
self.local_checkpoint_dir.mkdir(exist_ok=True, parents=True)
download_checkpoints_from_run_id(self.run_id, str(self.remote_checkpoint_path),
self.local_checkpoint_dir,
aml_workspace=workspace)
assert self.local_checkpoint_path.exists()
@ -608,9 +612,12 @@ def get_workspace(aml_workspace: Optional[Workspace] = None, workspace_config_pa
else:
raise ValueError("No workspace config file given, nor can we find one.")
if workspace_config_path.is_file():
if not isinstance(workspace_config_path, Path):
raise ValueError("Workspace config path is not a path, check your input.")
elif workspace_config_path.is_file():
auth = get_authentication()
return Workspace.from_config(path=str(workspace_config_path), auth=auth)
raise ValueError("Workspace config file does not exist or cannot be read.")
@ -1046,6 +1053,7 @@ def get_run_file_names(run: Run, prefix: str = "") -> List[str]:
:return: A list of paths within the Run's container
"""
all_files = run.get_file_names()
print(f"Selecting files with prefix {prefix}")
return [f for f in all_files if f.startswith(prefix)] if prefix else all_files

Просмотреть файл

@ -116,11 +116,17 @@ def test_get_workspace(
with pytest.raises(ValueError) as ex:
util.get_workspace(None, None)
assert "No workspace config file given" in str(ex)
# Workspace config file is set to a file that does not exist
with pytest.raises(ValueError) as ex:
util.get_workspace(None, workspace_config_path=tmp_path / "does_not_exist")
assert "Workspace config file does not exist" in str(ex)
# Workspace config file is set to a wrong type
with pytest.raises(ValueError) as ex:
util.get_workspace(None, workspace_config_path=1) # type: ignore
assert "Workspace config path is not a path" in str(ex)
@patch("health_azure.utils.Run")
def test_create_run_recovery_id(mock_run: MagicMock) -> None: