зеркало из https://github.com/microsoft/hi-ml.git
ENH: Add account key argument to generate checkpoint URL script (#874)
Some datastores don't have access to account key, so we need a way of specifying it via commandline
This commit is contained in:
Родитель
122cd12215
Коммит
f7e91a6a51
|
@ -10,6 +10,7 @@ from typing import Optional
|
|||
def get_checkpoint_url_from_aml_run(
|
||||
run_id: str,
|
||||
checkpoint_filename: str,
|
||||
account_key: str,
|
||||
expiry_days: int = 1,
|
||||
workspace_config_path: Optional[Path] = None,
|
||||
aml_workspace: Optional[Workspace] = None,
|
||||
|
@ -19,6 +20,7 @@ def get_checkpoint_url_from_aml_run(
|
|||
|
||||
:param run_id: The run ID of the checkpoint.
|
||||
:param checkpoint_filename: The filename of the checkpoint.
|
||||
:param account_key: The Azure Storage account key to use for the SAS token.
|
||||
:param expiry_days: The number of days the SAS URL is valid for, defaults to 30.
|
||||
:param workspace_config_path: The path to the workspace config file, defaults to None.
|
||||
:param aml_workspace: The Azure ML workspace to use, defaults to the default workspace.
|
||||
|
@ -26,17 +28,16 @@ def get_checkpoint_url_from_aml_run(
|
|||
:return: The SAS URL for the checkpoint.
|
||||
"""
|
||||
workspace = get_workspace(aml_workspace=aml_workspace, workspace_config_path=workspace_config_path)
|
||||
datastore = workspace.get_default_datastore()
|
||||
account_name = datastore.account_name
|
||||
account_name = workspace.get_details()['storageAccount'].split('/')[-1]
|
||||
print(f"Workspace {workspace.name} stores its run results in storage account {account_name}.")
|
||||
container_name = 'azureml'
|
||||
blob_name = f'ExperimentRun/dcid.{run_id}/{DEFAULT_AML_CHECKPOINT_DIR}/{checkpoint_filename}'
|
||||
|
||||
if not sas_token:
|
||||
sas_token = generate_blob_sas(
|
||||
account_name=datastore.account_name,
|
||||
account_name=account_name,
|
||||
container_name=container_name,
|
||||
blob_name=blob_name,
|
||||
account_key=datastore.account_key,
|
||||
account_key=account_key,
|
||||
permission=BlobSasPermissions(read=True),
|
||||
expiry=datetime.utcnow() + timedelta(days=expiry_days),
|
||||
)
|
||||
|
@ -62,6 +63,9 @@ if __name__ == '__main__':
|
|||
default=30,
|
||||
help='The number of hours for which the SAS token is valid. Default: 30 for 1 month',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--account_key', default='', type=str, help='The Azure Storage account key to use for the SAS token.'
|
||||
)
|
||||
args = parser.parse_args()
|
||||
workspace_config_path = Path(args.workspace_config) if args.workspace_config else None
|
||||
url = get_checkpoint_url_from_aml_run(
|
||||
|
@ -69,5 +73,6 @@ if __name__ == '__main__':
|
|||
checkpoint_filename=args.checkpoint_filename,
|
||||
expiry_days=args.expiry_days,
|
||||
workspace_config_path=workspace_config_path,
|
||||
account_key=args.account_key,
|
||||
)
|
||||
print(f'Checkpoint URL: {url}')
|
||||
|
|
|
@ -42,11 +42,13 @@ def test_load_ssl_checkpoint_from_local_file(tmp_path: Path) -> None:
|
|||
|
||||
|
||||
def test_load_ssl_checkpoint_from_url(tmp_path: Path) -> None:
|
||||
aml_workspace = DEFAULT_WORKSPACE.workspace
|
||||
blob_url = get_checkpoint_url_from_aml_run(
|
||||
run_id=TEST_SSL_RUN_ID,
|
||||
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME,
|
||||
expiry_days=1,
|
||||
aml_workspace=DEFAULT_WORKSPACE.workspace,
|
||||
aml_workspace=aml_workspace,
|
||||
account_key=aml_workspace.get_default_datastore().account_key,
|
||||
)
|
||||
encoder_params = EncoderParams(encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(blob_url))
|
||||
assert encoder_params.ssl_checkpoint.is_url
|
||||
|
|
Загрузка…
Ссылка в новой задаче