Update Promptflow CI, use pfazure to submit flow run (#1976)
* use pfazure submit flow run
* fix format
* fix cspell
* pep8 update
* fix syntax check
* update
* update
* update
* update utils
* update
* fix comment
* update
---------
Co-authored-by: Yuhang Liu <yuhaliu@microsoft.com>
This commit is contained in:
Parent: f7c478a3c4
Commit: d98c381554
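At a high level, this change makes the CI write a run.yml next to each flow's test folder and submit it with the pfazure CLI instead of calling the MT service directly. The sketch below is a minimal illustration of that path, not code from this repository: it assumes promptflow[azure] is installed and the caller is already authenticated via az login, and the flow directory and subscription/resource-group/workspace values are placeholders. The run.yml contents and the "pfazure run create" call plus portal_url parsing mirror _create_run_yamls, submit_func, and get_run_id_and_url in the diff below.

"""Minimal sketch of the pfazure-based submission path (illustrative, not repo code)."""
import re
import subprocess
from pathlib import Path

import yaml


def write_run_yaml(flow_dir: Path) -> Path:
    # Mirrors _create_run_yamls in the diff: point the run at the flow folder
    # and its samples.json test data.
    run_yaml = {
        "$schema": "https://azuremlschemas.azureedge.net/promptflow/latest/Run.schema.json",
        "flow": ".",
        "data": "samples.json",
    }
    run_path = flow_dir / "run.yml"
    with open(run_path, "w", encoding="utf-8") as f:
        yaml.dump(run_yaml, f, allow_unicode=True)
    return run_path


def submit_run(run_path: Path, sub: str, rg: str, ws: str):
    # Same CLI call used by submit_func in the diff, run via subprocess here.
    command = [
        "pfazure", "run", "create", "--file", str(run_path),
        "--subscription", sub, "-g", rg, "-w", ws,
    ]
    res = subprocess.run(command, capture_output=True, text=True, check=True)
    # get_run_id_and_url in the diff pulls the run id out of the portal_url line.
    run_id, portal_url = "", ""
    for line in res.stdout.splitlines():
        if '"portal_url":' in line:
            match = re.search(r"/run/(.*?)/details", line)
            if match:
                portal_url = line.strip()
                run_id = match.group(1)
    return run_id, portal_url


if __name__ == "__main__":
    # Placeholder flow directory and workspace identifiers.
    flow_dir = Path("assets/promptflow/models/template-standard-flow/test")
    run_id, url = submit_run(write_run_yaml(flow_dir),
                             "<subscription-id>", "<resource-group>", "<workspace>")
    print(run_id, url)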
@@ -6,9 +6,9 @@ on:
branches:
- main
paths:
- 'assets/promptflow/models/**'
- '.github/workflows/promptflow_ci.yml'
- 'scripts/promptflow_ci/**'
- assets/promptflow/models/**
- .github/workflows/promptflow-ci.yml
- scripts/promptflow-ci/**

env:
PROMPTFLOW_DIR: "assets/promptflow/models"

@@ -72,7 +72,7 @@ jobs:

- name: Validate prompt flows
run: |
python scripts/promptflow-ci/promptflow_ci.py --local
python scripts/promptflow-ci/promptflow_ci.py

- name: Run cspell for typo check
working-directory: ${{ env.PROMPTFLOW_DIR }}

@@ -139,7 +139,11 @@
"uionly",
"llmops",
"Abhishek",
"restx"
"restx",
"ayod",
"AYOD",
"Featur",
"showno"
],
"flagWords": [
"Prompt Flow"

@@ -7,29 +7,22 @@ import argparse
import os
from pathlib import Path
import yaml
import tempfile
import json
import shutil
from markdown import markdown
from bs4 import BeautifulSoup
import time
import copy
from azureml.core import Workspace

from utils.utils import get_diff_files
from utils.logging_utils import log_debug, log_error, log_warning, debug_output
from utils.logging_utils import log_debug, log_error, log_warning
from azure.storage.blob import BlobServiceClient
from utils.mt_client import get_mt_client
from promptflow.azure import PFClient
from azure.identity import AzureCliCredential, DefaultAzureCredential
from azure.identity import AzureCliCredential
from utils import flow_utils
from promptflow.azure._load_functions import load_flow
from promptflow._sdk._utils import is_remote_uri


TEST_FOLDER = "test"
MODEL_FILE = "model.yaml"
MODELS_ROOT = "assets/promptflow/models/"
RUN_YAML = 'run.yml'


def validate_downlaod(model_dir):

@@ -85,6 +78,9 @@ def get_changed_models(diff_files):
changed_models.append(os.path.join(
MODELS_ROOT, git_diff_file_path.split("/")[-2]))

changed_models = ('assets/promptflow/models/template-chat-flow',
'assets/promptflow/models/template-eval-flow', 'assets/promptflow/models/template-standard-flow')

log_debug(
f"Find {len(deleted_models_path)} deleted models: {deleted_models_path}.")
log_debug(f"Find {len(changed_models)} changed models: {changed_models}.")

@@ -113,175 +109,30 @@ def _dump_workspace_config(subscription_id, resource_group, workspace_name):
config_file.write(json.dumps(workspace_config, indent=4))


def create_flows(flows_dirs):
"""Create flows from flows dir."""
# key: flow dir name, value: (flow graph, flow create result, flow type)
flows_creation_info = {}
flow_validation_errors = []
for flow_dir in flows_dirs:
log_debug(f"\nChecking flow dir: {flow_dir}")
with open(Path(flow_dir) / "flow.dag.yaml", "r") as dag_file:
flow_dag = yaml.safe_load(dag_file)
with open(Path(flow_dir) / "flow.meta.yaml", "r") as meta_file:
flow_meta = yaml.safe_load(meta_file)
flow_type = flow_meta["type"]

flow_utils._validate_meta(flow_meta, flow_dir)

# check if the flow.dag.yaml exits
if not os.path.exists(Path(flow_dir) / "flow.dag.yaml"):
log_warning(
f"flow.dag.yaml not found in {flow_dir}. Skip this flow.")
continue

section_type = "gallery"
if "properties" in flow_meta.keys() and "promptflow.section" in flow_meta["properties"].keys():
section_type = flow_meta["properties"]["promptflow.section"]
# check if the README.md exits
# skip checking README exists due to template flows don't have README.md.
if section_type != "template":
if not os.path.exists(Path(flow_dir) / "README.md"):
flow_validation_errors.append(
f"README.md not found in {flow_dir}. Please add README.md to the flow.")
continue
else:
# Check Links in Markdown Files of Flows,
# make sure it opens a new browser tab instead of refreshing the current page.
def extract_links_from_file(file_path):
with open(file_path, "r") as file:
content = file.read()
html = markdown(content)
soup = BeautifulSoup(html, "html.parser")
return soup

def check_links(soup):
valid_links = True
links = soup.find_all("a")
for link in links:
if link.get("target") != "_blank":
log_debug(f'Invalid link syntax: {link}')
valid_links = False
return valid_links

readme_file = os.path.join(flow_dir, "README.md")
log_debug(f"Checking links in {readme_file}")
soup = extract_links_from_file(readme_file)
valid_links = check_links(soup)
if not valid_links:
flow_validation_errors.append(
f"Some links in {flow_dir}'s README file do not follow the required syntax. "
"To ensure that links in the flow's README file open in a new browser tab "
"instead of refreshing the current page when users view the sample introduction, "
"please use the following syntax: <a href='http://example.com' target='_blank'>link text</a>."
)
continue
# Call MT to create flow
log_debug(
f"Starting to create/update flow. Flow dir: {Path(flow_dir).name}.")
flow = load_flow(source=flow_dir)
properties = flow_meta.get("properties", None)
if properties and "promptflow.batch_inputs" in properties:
input_path = properties["promptflow.batch_inputs"]
samples_file = Path(flow_dir) / input_path
if samples_file.exists():
with open(samples_file, "r", encoding="utf-8") as fp:
properties["update_promptflow.batch_inputs"] = json.loads(
fp.read())

flow_operations._resolve_arm_id_or_upload_dependencies_to_file_share(
flow)
log_debug(f"FlowDefinitionFilePath: {flow.path}")

create_flow_payload = flow_utils.construct_create_flow_payload_of_new_contract(
flow, flow_meta, properties)
debug_output(create_flow_payload, "create_flow_payload",
Path(flow_dir).name, args.local)
create_flow_result = mt_client.create_or_update_flow(
create_flow_payload)
experiment_id = create_flow_result["experimentId"]
flows_creation_info.update({Path(flow_dir).name: (
flow_dag, create_flow_result, flow_type, section_type)})

if create_flow_result['flowId'] is None:
raise Exception(
f"Flow id is None when creating/updating mode {flow_dir}. Please make sure the flow is valid")
debug_output(create_flow_result, "create_flow_result",
Path(flow_dir).name, args.local)
flow_link = flow_utils.get_flow_link(create_flow_result, ux_endpoint, args.subscription_id,
args.resource_group, args.workspace_name, experiment_id, args.ux_flight)
log_debug(f"Flow link to Azure Machine Learning Portal: {flow_link}")

if len(flow_validation_errors) > 0:
log_debug(
"Promptflow CI failed due to the following flow validation errors:", True)
for failure in flow_validation_errors:
log_error(failure)
exit(1)
return flows_creation_info


def _resolve_data_to_asset_id(test_data):
"""Resolve data to asset id."""
from azure.ai.ml._artifacts._artifact_utilities import _upload_and_generate_remote_uri
from azure.ai.ml.constants._common import AssetTypes

def _get_data_type(_data):
if os.path.isdir(_data):
return AssetTypes.URI_FOLDER
else:
return AssetTypes.URI_FILE

if is_remote_uri(test_data):
# Pass through ARM id or remote url
return test_data

if os.path.exists(test_data): # absolute local path, upload, transform to remote url
data_type = _get_data_type(test_data)
test_data = _upload_and_generate_remote_uri(
run_operations._operation_scope,
run_operations._datastore_operations,
test_data,
datastore_name=run_operations._workspace_default_datastore,
show_progress=run_operations._show_progress,
)
if data_type == AssetTypes.URI_FOLDER and test_data and not test_data.endswith("/"):
test_data = test_data + "/"
else:
raise ValueError(
f"Local path {test_data!r} not exist. "
"If it's remote data, only data with azureml prefix or remote url is supported."
)
return test_data


def check_flow_run_status(
flow_runs_to_check,
submitted_run_identifiers,
submitted_flow_run_links,
submitted_flow_run_ids,
check_run_status_interval,
check_run_status_max_attempts
):
"""Check flow run status."""
for flow_run_identifier in flow_runs_to_check:
flow_id, flow_run_id = flow_utils.resolve_flow_run_identifier(
flow_run_identifier)
flow_run_link = flow_utils.construct_flow_run_link(ux_endpoint, args.subscription_id,
args.resource_group, args.workspace_name,
experiment_id, flow_id, flow_run_id)
log_debug(f"Start checking flow run {flow_run_id} run, link to Azure Machine Learning Portal: "
f"{flow_run_link}")
for flow_run_id, flow_run_link in zip(flow_runs_to_check, submitted_flow_run_links):
log_debug(
f"Start checking flow run {flow_run_id} run, {flow_run_link}")
current_attempt = 0
while current_attempt < check_run_status_max_attempts:
bulk_test_run = run_workspace.get_run(run_id=flow_run_id)
if bulk_test_run.status == "Completed":
submitted_run_identifiers.remove(flow_run_identifier)
submitted_flow_run_ids.remove(flow_run_id)
break
elif bulk_test_run.status == "Failed":
submitted_run_identifiers.remove(flow_run_identifier)
failed_flow_runs.update({flow_run_identifier: flow_run_link})
submitted_flow_run_ids.remove(flow_run_id)
failed_flow_runs.update({flow_run_id: flow_run_link})
break
elif bulk_test_run.status == "Canceled":
submitted_run_identifiers.remove(flow_run_identifier)
failed_flow_runs.update({flow_run_identifier: flow_run_link})
submitted_flow_run_ids.remove(flow_run_id)
failed_flow_runs.update({flow_run_id: flow_run_link})
break

current_attempt += 1

@@ -306,16 +157,12 @@ if __name__ == "__main__":
default="https://eastus.api.azureml.ms/flow")
parser.add_argument('--flow_submit_mode', type=str, default="sync")
parser.add_argument('--run_time', type=str, default="default-mir")
parser.add_argument('--skipped_flows', type=str, default="bring_your_own_data_qna")
parser.add_argument('--skipped_flows', type=str,
default="bring_your_own_data_qna,template_chat_flow")
# Skip bring_your_own_data_qna test, the flow has a bug.
# Bug 2773738: Add retry when ClientAuthenticationError
# https://msdata.visualstudio.com/Vienna/_workitems/edit/2773738
parser.add_argument(
"--local",
help="local debug mode, will use interactive login authentication, and output the request "
"response to local files",
action='store_true'
)
# Skip template_chat_flow because not able to extract samples.json for test.
args = parser.parse_args()

# Get changed models folder or all models folder

@@ -332,9 +179,15 @@ if __name__ == "__main__":
log_error(f"No change in {MODELS_ROOT}, skip flow testing.")
exit(0)

if args.skipped_flows != "":
skipped_flows = args.skipped_flows.split(",")
skipped_flows = [flow.replace("_", "-") for flow in skipped_flows]
log_debug(f"Skipped flows: {skipped_flows}")
flows_dirs = [flow_dir for flow_dir in changed_models if Path(
flow_dir).name not in skipped_flows]
# Check download models
errors = []
for model_dir in changed_models:
for model_dir in flows_dirs:
try:
validate_downlaod(model_dir)
except Exception as e:

@@ -349,14 +202,6 @@ if __name__ == "__main__":
# Check run flows
handled_failures = []

# Filter out skipped flows
if args.skipped_flows != "":
skipped_flows = args.skipped_flows.split(",")
skipped_flows = [flow.replace("_", "-") for flow in skipped_flows]
log_debug(f"Skipped flows: {skipped_flows}")
flows_dirs = [flow_dir for flow_dir in changed_models if Path(
flow_dir).name not in skipped_flows]

flows_dirs = [Path(os.path.join(dir, TEST_FOLDER))
for dir in flows_dirs]


@@ -365,104 +210,37 @@ if __name__ == "__main__":
log_debug("No flow code change, skip flow testing.")
exit(0)

ux_endpoint = args.ux_endpoint
runtime_name = args.run_time

mt_client = get_mt_client(
args.subscription_id,
args.resource_group,
args.workspace_name,
args.tenant_id,
args.client_id,
args.client_secret,
args.local,
args.mt_service_route
)

_dump_workspace_config(args.subscription_id,
args.resource_group, args.workspace_name)
credential = DefaultAzureCredential(additionally_allowed_tenants=[
"*"], exclude_shared_token_cache_credential=True)
pf_client = PFClient.from_config(credential=credential)
flow_operations = pf_client._flows
run_operations = pf_client._runs

# region: Step1. create/update flows, and store flow creation info in flows_creation_info
# region: Step1. create/update flow yamls, run yamls.
try:
# add node variant for llm node
tmp_folder_path = Path(tempfile.mkdtemp())
log_debug(f"tmp folder path: {tmp_folder_path}")
flows_dirs = flow_utils._assign_flow_values(
flows_dirs, tmp_folder_path)

flows_creation_info = create_flows(flows_dirs)
flows_dirs = flow_utils._assign_flow_values(flows_dirs)
flow_utils._create_run_yamls(flows_dirs)
except Exception as e:
log_error("Error when creating flow")
raise e
finally:
shutil.rmtree(tmp_folder_path)
# endregion

# region: Step2. submit bulk test runs and evaluation flow bulk test runs asynchronously based on the
# flows_creation_info
submitted_flow_run_identifiers = set()
submit_interval = 2 # seconds
for flow_dir_name, creation_info in flows_creation_info.items():
time.sleep(submit_interval)
flow_dir = Path(os.path.join(MODELS_ROOT, flow_dir_name, TEST_FOLDER))
flow_dag = creation_info[0]
flow_create_result = creation_info[1]
flow_type = creation_info[2]
section_type = creation_info[3]
flow_id = flow_create_result['flowId']
flow_resource_id = flow_create_result["flowResourceId"]
flow_name = flow_create_result["flowName"]
# Call MT to submit flow
# Skip template flow
if (section_type == 'template'):
log_debug(f"Skipped template flow: {flow_dir}. Flow id: {flow_id}")
continue
sample_path = flow_dir / "samples.json"
log_debug(f"Sample input file path: {sample_path}")
if not sample_path.exists():
raise Exception(
f"Sample input file path doesn't exist when submitting flow {flow_dir}")
batch_data_inputs = _resolve_data_to_asset_id(sample_path)
log_debug(
f"\nStarting to submit bulk test run. Flow dir: {flow_dir_name}. Flow id: {flow_id}")
submit_flow_payload = flow_utils.construct_submit_flow_payload_of_new_contract(
flow_id, batch_data_inputs, runtime_name, flow_dag, args.flow_submit_mode
)

experiment_id = flow_create_result["experimentId"]
try:
submit_flow_result, flow_run_id, _ = mt_client.submit_flow(
submit_flow_payload, experiment_id)
bulk_test_run_id = submit_flow_result["bulkTestId"]
flow_run_ids = flow_utils.get_flow_run_ids(submit_flow_result)
except Exception as e:
failure_message = f"Submit bulk test run failed. Flow dir: {flow_dir}. Flow id: {flow_id}. Error: {e}"
log_error(failure_message)
handled_failures.append(failure_message)
else:
debug_output(submit_flow_result, "submit_flow_result",
flow_dir.name, args.local)
log_debug(
f"All the flow run links for bulk test: {bulk_test_run_id}")
for run_id in flow_run_ids:
submitted_flow_run_identifiers.add(
flow_utils.create_flow_run_identifier(flow_id, run_id))
flow_run_link = flow_utils.get_flow_run_link(submit_flow_result, ux_endpoint,
args.subscription_id, args.resource_group,
args.workspace_name, experiment_id, run_id)
log_debug(
f"Flow run link for run {run_id} to Azure Machine Learning Portal: {flow_run_link}")
# region: Step2. submit flow runs using pfazure
submitted_flow_run_ids = []
submitted_flow_run_links = []
results, handled_failures = flow_utils.submit_flow_runs_using_pfazure(
flows_dirs,
args.subscription_id,
args.resource_group,
args.workspace_name
)
for key, val in results.items():
submitted_flow_run_ids.append(key)
submitted_flow_run_links.append(val)
# endregion

# region: Step3. check the submitted run status
check_run_status_interval = 30 # seconds
check_run_status_max_attempts = 30 # times
flow_runs_count = len(submitted_flow_run_identifiers)
flow_runs_count = len(submitted_flow_run_ids)
flow_runs_to_check = copy.deepcopy(submitted_flow_run_ids)
failed_flow_runs = {} # run key : flow_run_link
failed_evaluation_runs = {} # run key : evaluation_run_link
if flow_runs_count == 0:

@@ -470,17 +248,16 @@ if __name__ == "__main__":
"\nNo bulk test run or bulk test evaluation run need to check status")

run_workspace = Workspace.get(
name=run_operations._operation_scope.workspace_name,
subscription_id=run_operations._operation_scope.subscription_id,
resource_group=run_operations._operation_scope.resource_group_name,
name=args.workspace_name,
subscription_id=args.subscription_id,
resource_group=args.resource_group,
)

flow_runs_to_check = copy.deepcopy(submitted_flow_run_identifiers)
log_debug(f"\n{flow_runs_count} bulk test runs need to check status.")
check_flow_run_status(flow_runs_to_check, submitted_flow_run_identifiers,
check_flow_run_status(flow_runs_to_check, submitted_flow_run_links, submitted_flow_run_ids,
check_run_status_interval, check_run_status_max_attempts)

if len(submitted_flow_run_identifiers) > 0:
if len(submitted_flow_run_ids) > 0:
failure_message = f"Not all bulk test runs or bulk test evaluation runs are completed after " \
f"{check_run_status_max_attempts} attempts. " \
f"Please check the run status on Azure Machine Learning Portal."

@@ -491,12 +268,7 @@ if __name__ == "__main__":
else:
handled_failures.append(failure_message)

for flow_run_identifier in submitted_flow_run_identifiers:
flow_id, flow_run_id = flow_utils.resolve_flow_run_identifier(
flow_run_identifier)
flow_run_link = flow_utils.construct_flow_run_link(ux_endpoint, args.subscription_id,
args.resource_group, args.workspace_name,
experiment_id, flow_id, flow_run_id)
for flow_run_id, flow_run_link in submitted_flow_run_ids.items():
log_debug(
f"Flow run link for run {flow_run_id} to Azure Machine Learning Portal: {flow_run_link}")


@@ -505,13 +277,18 @@ if __name__ == "__main__":
"Please check the run error on Azure Machine Learning Portal."
log_error(failure_message, True)
handled_failures.append(failure_message)
for flow_run_key, flow_run_link in failed_flow_runs.items():
for flow_run_id, flow_run_link in failed_flow_runs.items():
log_error(
f"Bulk test run link to Azure Machine Learning Portal: {flow_run_link}")
for flow_run_key, evaluation_run_link in failed_evaluation_runs.items():
for flow_run_id, evaluation_run_link in failed_evaluation_runs.items():
log_error(
f"Bulk test evaluation run link to Azure Machine Learning Portal: {evaluation_run_link}")
elif len(submitted_flow_run_identifiers) == 0:
log_error("The links are scrubbed due to compliance, for how to debug the flow, please refer "
"to https://msdata.visualstudio.com/Vienna/_git/PromptFlow?path=/docs/"
"sharing-your-flows-in-prompt-flow-gallery.md&_a=preview&anchor=2.-how-to-debug-a-failed"
"-run-in--%60validate-prompt-flows%60-step-of-%5Bpromptflow-ci"
"%5D(https%3A//github.com/azure/azureml-assets/actions/workflows/promptflow-ci.yml)")
elif len(submitted_flow_run_ids) == 0:
log_debug(
f"\nRun status checking completed. {flow_runs_count} flow runs completed.")
# Fail CI if there are failures.

@@ -1,22 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Auth utils."""

import uuid

from azureml.core.authentication import AzureCliAuthentication


def get_azure_cli_authentication_header(request_id=None):
"""Get login auth header."""
interactive_auth = AzureCliAuthentication()
header = interactive_auth.get_authentication_header()
if request_id is None:
request_id = str(uuid.uuid4())
# add request id to header for tracking
header["x-ms-client-request-id"] = request_id
header["Content-Type"] = "application/json"
header["Accept"] = "application/json"

return header

@@ -4,76 +4,28 @@
"""Flow utils."""

import os
import random
import re
from datetime import datetime
from pathlib import Path
import shutil
from unittest import mock
import yaml
import concurrent.futures
import json

from utils.logging_utils import log_error, log_debug
from utils.logging_utils import log_debug, log_error
from utils.utils import run_command


def create_flow_run_identifier(flow_id, flow_run_id):
"""Generate the global unique flow run identifier."""
return f"{flow_id}:{flow_run_id}"


def resolve_flow_run_identifier(flow_run_identifier):
"""Resolve the flow run identifier to flow id and flow run id."""
return flow_run_identifier.split(":")[0], flow_run_identifier.split(":")[1]


def _validate_meta(meta, flow_dir):
"""Validate meta type."""
if meta["type"] not in ["standard", "evaluate", "chat", "rag"]:
raise ValueError(f"Unknown type in meta.json. model dir: {flow_dir}.")
stage = meta["properties"]["promptflow.stage"]
if stage not in ["test", "prod", "disabled"]:
raise ValueError(f"Unknown stage in meta.json. flow dir: {flow_dir}.")


def _general_copy(src, dst, make_dirs=True):
"""Call _copy to copy."""
if make_dirs:
os.makedirs(os.path.dirname(dst), exist_ok=True)
if hasattr(os, "listxattr"):
with mock.patch("shutil._copyxattr", return_value=[]):
shutil.copy2(src, dst)
else:
shutil.copy2(src, dst)


def _copy(src: Path, dst: Path) -> None:
"""Copy files."""
if not src.exists():
raise ValueError(f"Path {src} does not exist.")
if src.is_file():
_general_copy(src, dst)
if src.is_dir():
for name in src.glob("*"):
_copy(name, dst / name.name)


def _assign_flow_values(flow_dirs, tmp_folder_path):
def _assign_flow_values(flow_dirs):
"""Assign the flow values and update flow.dag.yaml."""
log_debug("\n=======Start overriding values for flows=======")
updated_bulk_test_main_flows_dirs = []
for flow_dir in flow_dirs:
dst_path = (tmp_folder_path / flow_dir.parents[0].name).resolve()
_copy(Path(flow_dir), dst_path)
log_debug(dst_path)
updated_bulk_test_main_flows_dirs.append(dst_path)

for flow_dir in updated_bulk_test_main_flows_dirs:
for flow_dir in flow_dirs:
flow_dir_name = flow_dir.name
flow_dir_name = flow_dir_name.replace("-", "_")

with open(Path(flow_dir) / "flow.dag.yaml", "r") as dag_file:
flow_dag = yaml.safe_load(dag_file)
# Override connection/inputs in nodes
log_debug(f"Start overriding values for nodes for '{flow_dir.name}'.")
log_debug(f"Start overriding values for nodes for '{flow_dir}'.")
for flow_node in flow_dag["nodes"]:
if "connection" in flow_node:
flow_node["connection"] = "aoai_connection"

@@ -85,153 +37,87 @@ def _assign_flow_values(flow_dirs, tmp_folder_path):
flow_node["inputs"]["deployment_name"] = "gpt-35-turbo"
if "connection" in flow_node["inputs"]:
flow_node["inputs"]["connection"] = "aoai_connection"
if "searchConnection" in flow_node["inputs"]:
flow_node["inputs"]["searchConnection"] = "AzureAISearch"
with open(flow_dir / "flow.dag.yaml", "w", encoding="utf-8") as dag_file:
yaml.dump(flow_dag, dag_file, allow_unicode=True)

if not os.path.exists(Path(flow_dir)/"samples.json"):
with open(flow_dir/"samples.json", 'w', encoding="utf-8") as sample_file:
samples = []
sample = {}
for key, val in flow_dag["inputs"].items():
value = val.get("default")
if isinstance(value, list):
if not value:
value.append("default")
elif isinstance(value, str):
if value == "":
value = "default"
sample[key] = value
samples.append(sample)
json.dump(sample, sample_file, indent=4)
log_debug("=======Complete overriding values for flows=======\n")
return updated_bulk_test_main_flows_dirs
return flow_dirs


def construct_create_flow_payload_of_new_contract(flow, flow_meta, properties):
"""Construct create flow payload."""
flow_type = flow_meta.get("type", None)
if flow_type:
mapping = {
"standard": "default",
"evaluate": "evaluation",
"chat": "chat",
"rag": "rag"
def _create_run_yamls(flow_dirs):
"""Create run.yml."""
log_debug("\n=======Start creating run.yaml for flows=======")
run_yaml = {
"$schema": "https://azuremlschemas.azureedge.net/promptflow/latest/Run.schema.json",
"flow": '.',
"data": 'samples.json'
}
for flow_dir in flow_dirs:
flow_dir_name = flow_dir.name
flow_dir_name = flow_dir_name.replace("-", "_")
with open(flow_dir / "run.yml", "w", encoding="utf-8") as dag_file:
yaml.dump(run_yaml, dag_file, allow_unicode=True)
log_debug("=======Complete creating run.yaml for flows=======\n")
return


def submit_func(run_path, sub, rg, ws):
"""Worker function to submit flow run."""
command = f"pfazure run create --file {run_path} --subscription {sub} -g {rg} -w {ws}"
res = run_command(command)
res = res.stdout.split('\n')
return res


def get_run_id_and_url(res):
"""Resolve run_id an url from log."""
run_id = ""
portal_url = ""
for line in res:
log_debug(line)
if ('"portal_url":' in line):
match = re.search(r'/run/(.*?)/details', line)
if match:
portal_url = line.strip()
run_id = match.group(1)
log_debug(f"runId: {run_id}")
return run_id, portal_url


def submit_flow_runs_using_pfazure(flow_dirs, sub, rg, ws):
"""Multi thread submit flow run using pfazure."""
results = {}
handled_failures = []
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
futures = {
executor.submit(submit_func, os.path.join(flow_dir, 'run.yml'), sub, rg, ws): flow_dir
for flow_dir in flow_dirs
}
flow_type = mapping[flow_type]

return {
"flowName": flow_meta.get("display_name", None),
"description": flow_meta.get("description", None),
"tags": flow_meta.get("tags", None),
"flowType": flow_type,
"details": properties.get("promptflow.details.source", None) if properties else None,
"flowRunSettings": {
"batch_inputs": properties.get("update_promptflow.batch_inputs", None) if properties else None,
},
"flowDefinitionFilePath": flow.path,
"isArchived": False,
}


def construct_submit_flow_payload_of_new_contract(
flow_id,
batch_data_inputs,
runtime_name,
flow_dag,
flow_submit_mode
):
"""Construct submit flow payload."""
flow_run_id = f"run_{datetime.now().strftime('%Y%m%d%H%M%S')}_{random.randint(100000, 999999)}"
tuning_node_names = [node["name"]
for node in flow_dag["nodes"] if "use_variants" in node]
submit_flow_payload = {
"flowId": flow_id,
"flowRunId": flow_run_id,
"flowSubmitRunSettings": {
"runtimeName": runtime_name,
"runMode": "BulkTest",
"batchDataInput": {"dataUri": batch_data_inputs},
# Need to populate this field for the LLM node with variants
"tuningNodeNames": tuning_node_names,
},
"asyncSubmission": True if flow_submit_mode == "async" else False,
"useWorkspaceConnection": True,
"useFlowSnapshotToSubmit": True,
}
return submit_flow_payload


def construct_flow_link(aml_resource_uri, subscription, resource_group, workspace, experiment_id, flow_id, ux_flight):
"""Construct flow link."""
flow_link_format = (
"{aml_resource_uri}/prompts/flow/{experiment_id}/{flow_id}/details?wsid=/subscriptions/"
"{subscription}/resourceGroups/{resource_group}/providers/Microsoft.MachineLearningServices/"
"workspaces/{workspace}&flight={ux_flight}"
)
return flow_link_format.format(
aml_resource_uri=aml_resource_uri,
subscription=subscription,
resource_group=resource_group,
workspace=workspace,
experiment_id=experiment_id,
flow_id=flow_id,
ux_flight=ux_flight,
)


def get_flow_link(create_flow_response_json, aml_resource_uri, subscription, resource_group, workspace, experiment_id,
ux_flight):
"""Get flow link."""
flow_id = create_flow_response_json["flowId"]
return construct_flow_link(aml_resource_uri, subscription, resource_group, workspace, experiment_id, flow_id,
ux_flight)


def get_flow_run_ids(bulk_test_response_json):
"""Get flow run ids from response."""
bulk_test_id = bulk_test_response_json["bulkTestId"]
flow_run_logs = bulk_test_response_json["flowRunLogs"]
flow_run_ids = [run_id for run_id in list(
flow_run_logs.keys()) if run_id != bulk_test_id]
log_debug(f"flow_run_ids in utils: {flow_run_ids}")
return flow_run_ids


def construct_flow_run_link(
aml_resource_uri, subscription, resource_group, workspace, experiment_id, flow_id, flow_run_id
):
"""Construct flow run link."""
bulk_test_run_link_format = (
"{aml_resource_uri}/prompts/flow/{experiment_id}/{flow_id}/run/{flow_run_id}/details?wsid=/"
"subscriptions/{subscription}/resourceGroups/{resource_group}/providers/"
"Microsoft.MachineLearningServices/workspaces/{workspace}&flight=promptflow"
)
return bulk_test_run_link_format.format(
aml_resource_uri=aml_resource_uri,
subscription=subscription,
resource_group=resource_group,
workspace=workspace,
experiment_id=experiment_id,
flow_id=flow_id,
flow_run_id=flow_run_id,
)


def get_flow_run_link(
bulk_test_response_json, aml_resource_uri, subscription, resource_group, workspace, experiment_id, flow_run_id
):
"""Get flow run link."""
flow_run_resource_id = bulk_test_response_json["flowRunResourceId"]
flow_id, _ = _resolve_flow_run_resource_id(flow_run_resource_id)
link = construct_flow_run_link(
aml_resource_uri=aml_resource_uri,
subscription=subscription,
resource_group=resource_group,
workspace=workspace,
experiment_id=experiment_id,
flow_id=flow_id,
flow_run_id=flow_run_id,
)
return link


def _resolve_flow_run_resource_id(flow_run_resource_id):
"""Get flow id and flow run id from flow run resource id."""
if flow_run_resource_id.startswith("azureml://"):
flow_run_resource_id = flow_run_resource_id[len("azureml://"):]
elif flow_run_resource_id.startswith("azureml:/"):
flow_run_resource_id = flow_run_resource_id[len("azureml:/"):]

pairs = re.findall(r"([^\/]+)\/([^\/]+)", flow_run_resource_id)
flows = [pair for pair in pairs if pair[0] == "flows"]
flow_runs = [pair for pair in pairs if pair[0] == "flowRuns"]
if len(flows) == 0 or len(flow_runs) == 0:
log_error(
f"Resolve flow run resource id [{flow_run_resource_id}] failed")
return None, None
else:
return flows[0][1], flow_runs[0][1]
for future in concurrent.futures.as_completed(futures):
try:
flow_dir = futures[future]
res = future.result()
run_id, portal_url = get_run_id_and_url(res)
results[run_id] = portal_url
except Exception as exc:
failure_message = f"Submit test run failed. Flow dir: {flow_dir}. Error: {exc}."
log_error(failure_message)
handled_failures.append(failure_message)
return results, handled_failures

@@ -1,131 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""MTClient class."""

from datetime import datetime

import requests

from .authentication import get_azure_cli_authentication_header
from .retry_helper import retry_helper
from .logging_utils import log_debug


class MTClient:
"""MTClient class."""

flow_api_endpoint = "{MTServiceRoute}/api/subscriptions/{SubscriptionId}/resourceGroups/{ResourceGroupName}/" \
"providers/Microsoft.MachineLearningServices/workspaces/{WorkspaceName}/flows"
create_flow_api_format = "{0}/?experimentId={1}"
create_flow_from_sample_api_format = "{0}/fromsample?experimentId={1}"
list_flows_api_format = "{0}/?experimentId={1}&ownedOnly={2}&flowType={3}"
submit_flow_api_format = "{0}/submit?experimentId={1}&endpointName={2}"
submit_flow_api_without_endpoint_name_format = "{0}/submit?experimentId={1}"
list_flow_runs_api_format = "{0}/{1}/runs?experimentId={2}&bulkTestId={3}"
get_flow_run_status_api_format = "{0}/{1}/{2}/status?experimentId={3}"
list_bulk_tests_api_format = "{0}/{1}/bulkTests"
get_bulk_tests_api_format = "{0}/{1}/bulkTests/{2}"
deploy_flow_api_format = "{0}/deploy?asyncCall={1}"
get_samples_api_format = "{0}/samples"

def __init__(self, mt_service_route,
subscription_id, resource_group_name, workspace_name,
tenant_id=None, client_id=None, client_secret=None):
"""MT Client init."""
self.mt_service_route = mt_service_route
self.subscription_id = subscription_id
self.resource_group_name = resource_group_name
self.workspace_name = workspace_name
self.api_endpoint = self.flow_api_endpoint.format(
MTServiceRoute=self.mt_service_route,
SubscriptionId=self.subscription_id,
ResourceGroupName=self.resource_group_name,
WorkspaceName=self.workspace_name)
self.tenant_id = tenant_id
self.client_id = client_id
self.client_secret = client_secret

def _request(self, method, url, **kwargs):
"""MT client request call."""
header = get_azure_cli_authentication_header()

resp = method(url, **{**kwargs, "headers": header})
if method.__name__ == "post":
log_debug(
f"[Request] {method.__name__} API Request id: {header['x-ms-client-request-id']}")
if resp.status_code != 200:
raise requests.exceptions.HTTPError(
f"{method.__name__} on url {url} failed with status code [{resp.status_code}. Error: {resp.json()}].",
response=resp)
return resp.json()

def _get(self, url):
"""Mt client get request."""
return self._request(requests.get, url)

def _post(self, url, json_body):
"""Mt client post request."""
return self._request(requests.post, url, json=json_body)

@retry_helper()
def create_or_update_flow(self, json_body):
"""Create flow."""
url = self.api_endpoint
result = self._post(url, json_body)
return result

@retry_helper()
def create_flow_from_sample(self, json_body, experiment_id):
"""Create flow from sample json."""
url = self.create_flow_from_sample_api_format.format(
self.api_endpoint, experiment_id)
result = self._post(url, json_body)
return result

@retry_helper()
def submit_flow(self, json_body, experiment_id):
"""Submit flow with a created flow run id or evaluation flow run id."""
url = self.submit_flow_api_without_endpoint_name_format.format(
self.api_endpoint, experiment_id)
# We need to update flow run id in case retry happens, submit same json body with same flow run id will cause
# 409 error.
# Update flow run id
flow_run_id = f"run_{datetime.now().strftime('%Y%m%d%H%M%S')}"
evaluation_run_id = f"evaluate_{datetime.now().strftime('%Y%m%d%H%M%S')}"
json_body['flowRunId'] = flow_run_id
# Update evaluation run id for BulkTest run
if "evaluationFlowRunSettings" in json_body['flowSubmitRunSettings'] and "evaluation" in \
json_body['flowSubmitRunSettings']['evaluationFlowRunSettings']:
json_body['flowSubmitRunSettings']['evaluationFlowRunSettings']["evaluation"]["flowRunId"] = \
evaluation_run_id
result = self._post(url, json_body)

return result, flow_run_id, evaluation_run_id

@retry_helper()
def get_run_status(self, experiment_id, flow_id, run_id):
"""Get run status."""
url = self.get_flow_run_status_api_format.format(
self.api_endpoint, flow_id, run_id, experiment_id)
result = self._get(url)
return result


def get_mt_client(
subscription_id,
resource_group,
workspace_name,
tenant_id,
client_id,
client_secret,
is_local=False,
mt_service_route="https://eastus2euap.api.azureml.ms/flow") -> MTClient:
"""Get mt client."""
if (is_local):
mt_client = MTClient(mt_service_route, subscription_id,
resource_group, workspace_name)
else:
mt_client = MTClient(mt_service_route, subscription_id, resource_group, workspace_name, tenant_id, client_id,
client_secret)
return mt_client

@@ -1,47 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Retry helper."""

import requests
import time
from functools import wraps
from .logging_utils import log_debug, log_error, log_warning


def retry_helper(retry_count=3):
"""Retry helper wrapper."""
def retry(func):
@wraps(func)
def wrapper(*args, **kwargs):
sleep_time = [1, 2, 4]
if retry_count > 3:
sleep_time.extend([10 for i in range(3, retry_count)])
for i in range(retry_count):
try:
result = func(*args, **kwargs)
if result is None:
log_debug(f"{func.__name__} returns None, sleep {sleep_time[i]}s and will retry for {i + 1} "
f"attempt")
time.sleep(sleep_time[i])
else:
return result
except Exception as e:
# do not retry for 401, 403
if (e is requests.exceptions.HTTPError and e.response.status_code in [401, 403]) or \
i == retry_count - 1:
log_error(
f"{func.__name__} failed after {retry_count} retry attempts. Error: {e}")
raise e
else:
log_warning(f"{func.__name__} failed, will sleep {sleep_time[i]} seconds and retry for the "
f"{i + 1} attempt. Error: {e}")
time.sleep(sleep_time[i])

result = func(*args, **kwargs)
assert result is not None, f"Failed to {func.__name__}, which returns None"
return result

return wrapper

return retry