ENH: Add account key argument to generate checkpoint URL script (#874)

Some datastores don't have access to account key, so we need a way of
specifying it via commandline
This commit is contained in:
Kenza Bouzid 2023-04-20 08:51:16 +01:00 коммит произвёл GitHub
Родитель 122cd12215
Коммит f7e91a6a51
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 13 добавлений и 6 удалений

Просмотреть файл

@ -10,6 +10,7 @@ from typing import Optional
def get_checkpoint_url_from_aml_run(
run_id: str,
checkpoint_filename: str,
account_key: str,
expiry_days: int = 1,
workspace_config_path: Optional[Path] = None,
aml_workspace: Optional[Workspace] = None,
@ -19,6 +20,7 @@ def get_checkpoint_url_from_aml_run(
:param run_id: The run ID of the checkpoint.
:param checkpoint_filename: The filename of the checkpoint.
:param account_key: The Azure Storage account key to use for the SAS token.
:param expiry_days: The number of days the SAS URL is valid for, defaults to 30.
:param workspace_config_path: The path to the workspace config file, defaults to None.
:param aml_workspace: The Azure ML workspace to use, defaults to the default workspace.
@ -26,17 +28,16 @@ def get_checkpoint_url_from_aml_run(
:return: The SAS URL for the checkpoint.
"""
workspace = get_workspace(aml_workspace=aml_workspace, workspace_config_path=workspace_config_path)
datastore = workspace.get_default_datastore()
account_name = datastore.account_name
account_name = workspace.get_details()['storageAccount'].split('/')[-1]
print(f"Workspace {workspace.name} stores its run results in storage account {account_name}.")
container_name = 'azureml'
blob_name = f'ExperimentRun/dcid.{run_id}/{DEFAULT_AML_CHECKPOINT_DIR}/{checkpoint_filename}'
if not sas_token:
sas_token = generate_blob_sas(
account_name=datastore.account_name,
account_name=account_name,
container_name=container_name,
blob_name=blob_name,
account_key=datastore.account_key,
account_key=account_key,
permission=BlobSasPermissions(read=True),
expiry=datetime.utcnow() + timedelta(days=expiry_days),
)
@ -62,6 +63,9 @@ if __name__ == '__main__':
default=30,
help='The number of hours for which the SAS token is valid. Default: 30 for 1 month',
)
parser.add_argument(
'--account_key', default='', type=str, help='The Azure Storage account key to use for the SAS token.'
)
args = parser.parse_args()
workspace_config_path = Path(args.workspace_config) if args.workspace_config else None
url = get_checkpoint_url_from_aml_run(
@ -69,5 +73,6 @@ if __name__ == '__main__':
checkpoint_filename=args.checkpoint_filename,
expiry_days=args.expiry_days,
workspace_config_path=workspace_config_path,
account_key=args.account_key,
)
print(f'Checkpoint URL: {url}')

Просмотреть файл

@ -42,11 +42,13 @@ def test_load_ssl_checkpoint_from_local_file(tmp_path: Path) -> None:
def test_load_ssl_checkpoint_from_url(tmp_path: Path) -> None:
aml_workspace = DEFAULT_WORKSPACE.workspace
blob_url = get_checkpoint_url_from_aml_run(
run_id=TEST_SSL_RUN_ID,
checkpoint_filename=LAST_CHECKPOINT_FILE_NAME,
expiry_days=1,
aml_workspace=DEFAULT_WORKSPACE.workspace,
aml_workspace=aml_workspace,
account_key=aml_workspace.get_default_datastore().account_key,
)
encoder_params = EncoderParams(encoder_type=SSLEncoder.__name__, ssl_checkpoint=CheckpointParser(blob_url))
assert encoder_params.ssl_checkpoint.is_url