зеркало из https://github.com/microsoft/hi-ml.git
Fixing checkpoint downloader (#174)
This PR changes the CheckpointDownloader class to fix a bug in the local path of the downloaded checkpoint.
This commit is contained in:
Родитель
484a0a8b28
Коммит
7902a6e138
|
@ -43,6 +43,7 @@ the section headers (Added/Changed/...) and incrementing the package version.
|
|||
- ([#161](https://github.com/microsoft/hi-ml/pull/161)) Empty string as target folder for a dataset creates an invalid mounting path for the dataset in AzureML (fixes #160)
|
||||
- ([#167](https://github.com/microsoft/hi-ml/pull/167)) Fix bugs in logging hyperparameters: logging as name/value
|
||||
table, rather than one column per hyperparameter. Use string logging for all hyperparameters
|
||||
- ([#174](https://github.com/microsoft/hi-ml/pull/174)) Fix bugs in the returned local_checkpoint_path when downloading checkpoints from an AML run
|
||||
|
||||
### Removed
|
||||
|
||||
|
|
|
@ -467,7 +467,7 @@ class RunIdOrListParam(CustomTypeParam):
|
|||
class CheckpointDownloader:
|
||||
def __init__(self, run_id: str, checkpoint_filename: str, azure_config_json_path: Path = None,
|
||||
aml_workspace: Workspace = None, download_dir: PathOrString = "checkpoints",
|
||||
remote_checkpoint_folder: PathOrString = "checkpoints") -> None:
|
||||
remote_checkpoint_dir: PathOrString = "checkpoints") -> None:
|
||||
"""
|
||||
Utility class for downloading checkpoint files from an Azure ML run
|
||||
|
||||
|
@ -480,25 +480,29 @@ class CheckpointDownloader:
|
|||
:param aml_workspace: An optional Azure ML Workspace object. If not running inside an AML Run, and no
|
||||
azure_config_json_path is provided, this is required.
|
||||
:param download_dir: The local directory in which to save the downloaded checkpoint files.
|
||||
:param remote_checkpoint_folder: The remote folder from which to download the checkpoint file
|
||||
:param remote_checkpoint_dir: The remote folder from which to download the checkpoint file
|
||||
"""
|
||||
self.azure_config_json_path = azure_config_json_path
|
||||
self.aml_workspace = aml_workspace
|
||||
self.run_id = run_id
|
||||
self.checkpoint_filename = checkpoint_filename
|
||||
self.download_dir = Path(download_dir)
|
||||
self.remote_checkpoint_folder = Path(remote_checkpoint_folder)
|
||||
self.remote_checkpoint_dir = Path(remote_checkpoint_dir)
|
||||
|
||||
@property
|
||||
def local_checkpoint_path(self) -> Path:
|
||||
def local_checkpoint_dir(self) -> Path:
|
||||
# in case run_id is a run recovery id, extract the run id
|
||||
run_id_parts = self.run_id.split(":")
|
||||
run_id = run_id_parts[-1]
|
||||
return self.download_dir / run_id / self.checkpoint_filename
|
||||
return self.download_dir / run_id
|
||||
|
||||
@property
|
||||
def remote_checkpoint_path(self) -> Path:
|
||||
return self.remote_checkpoint_folder / self.checkpoint_filename
|
||||
return self.remote_checkpoint_dir / self.checkpoint_filename
|
||||
|
||||
@property
|
||||
def local_checkpoint_path(self) -> Path:
|
||||
return self.local_checkpoint_dir / self.remote_checkpoint_path
|
||||
|
||||
def download_checkpoint_if_necessary(self) -> Path:
|
||||
"""Downloads the specified checkpoint if it does not already exist.
|
||||
|
@ -509,9 +513,9 @@ class CheckpointDownloader:
|
|||
workspace_config_path=self.azure_config_json_path)
|
||||
|
||||
if not self.local_checkpoint_path.exists():
|
||||
local_checkpoint_dir = self.local_checkpoint_path.parent
|
||||
local_checkpoint_dir.mkdir(exist_ok=True, parents=True)
|
||||
download_checkpoints_from_run_id(self.run_id, str(self.remote_checkpoint_path), local_checkpoint_dir,
|
||||
self.local_checkpoint_dir.mkdir(exist_ok=True, parents=True)
|
||||
download_checkpoints_from_run_id(self.run_id, str(self.remote_checkpoint_path),
|
||||
self.local_checkpoint_dir,
|
||||
aml_workspace=workspace)
|
||||
assert self.local_checkpoint_path.exists()
|
||||
|
||||
|
@ -608,9 +612,12 @@ def get_workspace(aml_workspace: Optional[Workspace] = None, workspace_config_pa
|
|||
else:
|
||||
raise ValueError("No workspace config file given, nor can we find one.")
|
||||
|
||||
if workspace_config_path.is_file():
|
||||
if not isinstance(workspace_config_path, Path):
|
||||
raise ValueError("Workspace config path is not a path, check your input.")
|
||||
elif workspace_config_path.is_file():
|
||||
auth = get_authentication()
|
||||
return Workspace.from_config(path=str(workspace_config_path), auth=auth)
|
||||
|
||||
raise ValueError("Workspace config file does not exist or cannot be read.")
|
||||
|
||||
|
||||
|
@ -1046,6 +1053,7 @@ def get_run_file_names(run: Run, prefix: str = "") -> List[str]:
|
|||
:return: A list of paths within the Run's container
|
||||
"""
|
||||
all_files = run.get_file_names()
|
||||
print(f"Selecting files with prefix {prefix}")
|
||||
return [f for f in all_files if f.startswith(prefix)] if prefix else all_files
|
||||
|
||||
|
||||
|
|
|
@ -116,11 +116,17 @@ def test_get_workspace(
|
|||
with pytest.raises(ValueError) as ex:
|
||||
util.get_workspace(None, None)
|
||||
assert "No workspace config file given" in str(ex)
|
||||
|
||||
# Workspace config file is set to a file that does not exist
|
||||
with pytest.raises(ValueError) as ex:
|
||||
util.get_workspace(None, workspace_config_path=tmp_path / "does_not_exist")
|
||||
assert "Workspace config file does not exist" in str(ex)
|
||||
|
||||
# Workspace config file is set to a wrong type
|
||||
with pytest.raises(ValueError) as ex:
|
||||
util.get_workspace(None, workspace_config_path=1) # type: ignore
|
||||
assert "Workspace config path is not a path" in str(ex)
|
||||
|
||||
|
||||
@patch("health_azure.utils.Run")
|
||||
def test_create_run_recovery_id(mock_run: MagicMock) -> None:
|
||||
|
|
Загрузка…
Ссылка в новой задаче