Added get_latest_model method (#231)
This commit is contained in:
Родитель
58876338ed
Коммит
d531b2e3eb
|
@ -26,7 +26,7 @@ POSSIBILITY OF SUCH DAMAGE.
|
|||
from azureml.core import Run
|
||||
import argparse
|
||||
import traceback
|
||||
from util.model_helper import get_model_by_tag
|
||||
from util.model_helper import get_latest_model
|
||||
|
||||
run = Run.get_context()
|
||||
|
||||
|
@ -45,7 +45,7 @@ run = Run.get_context()
|
|||
# sources_dir = 'diabetes_regression'
|
||||
# path_to_util = os.path.join(".", sources_dir, "util")
|
||||
# sys.path.append(os.path.abspath(path_to_util)) # NOQA: E402
|
||||
# from model_helper import get_model_by_tag
|
||||
# from model_helper import get_latest_model
|
||||
# workspace_name = os.environ.get("WORKSPACE_NAME")
|
||||
# experiment_name = os.environ.get("EXPERIMENT_NAME")
|
||||
# resource_group = os.environ.get("RESOURCE_GROUP")
|
||||
|
@ -108,7 +108,7 @@ try:
|
|||
firstRegistration = False
|
||||
tag_name = 'experiment_name'
|
||||
|
||||
model = get_model_by_tag(
|
||||
model = get_latest_model(
|
||||
model_name, tag_name, exp.name, ws)
|
||||
|
||||
if (model is not None):
|
||||
|
|
|
@ -22,20 +22,20 @@ def get_current_workspace() -> Workspace:
|
|||
return experiment.workspace
|
||||
|
||||
|
||||
def get_model_by_tag(
|
||||
def get_latest_model(
|
||||
model_name: str,
|
||||
tag_name: str,
|
||||
tag_value: str,
|
||||
tag_name: str = None,
|
||||
tag_value: str = None,
|
||||
aml_workspace: Workspace = None
|
||||
) -> AMLModel:
|
||||
"""
|
||||
Retrieves and returns the latest model from the workspace
|
||||
by its name and tag.
|
||||
by its name and (optional) tag.
|
||||
|
||||
Parameters:
|
||||
aml_workspace (Workspace): aml.core Workspace that the model lives.
|
||||
model_name (str): name of the model we are looking for
|
||||
tag (str): the tag value the model was registered under.
|
||||
(optional) tag (str): the tag value & name the model was registered under.
|
||||
|
||||
Return:
|
||||
A single aml model from the workspace that matches the name and tag.
|
||||
|
@ -44,37 +44,44 @@ def get_model_by_tag(
|
|||
# Validate params. cannot be None.
|
||||
if model_name is None:
|
||||
raise ValueError("model_name[:str] is required")
|
||||
if tag_name is None:
|
||||
raise ValueError("tag_name[:str] is required")
|
||||
if tag_value is None:
|
||||
raise ValueError("tag[:str] is required")
|
||||
|
||||
if aml_workspace is None:
|
||||
print("No workspace defined - using current experiment workspace.")
|
||||
aml_workspace = get_current_workspace()
|
||||
|
||||
# get model by tag.
|
||||
model_list = AMLModel.list(
|
||||
aml_workspace, name=model_name,
|
||||
tags=[[tag_name, tag_value]], latest=True
|
||||
)
|
||||
model_list = None
|
||||
tag_ext = ""
|
||||
|
||||
# Get lastest model
|
||||
# True: by name and tags
|
||||
if tag_name is not None and tag_value is not None:
|
||||
model_list = AMLModel.list(
|
||||
aml_workspace, name=model_name,
|
||||
tags=[[tag_name, tag_value]], latest=True
|
||||
)
|
||||
tag_ext = f"tag_name: {tag_name}, tag_value: {tag_value}."
|
||||
# False: Only by name
|
||||
else:
|
||||
model_list = AMLModel.list(
|
||||
aml_workspace, name=model_name, latest=True)
|
||||
|
||||
# latest should only return 1 model, but if it does,
|
||||
# then maybe sdk or source code changed.
|
||||
should_not_happen = ("Found more than one model "
|
||||
"for the latest with {{tag_name: {tag_name},"
|
||||
"tag_value: {tag_value}. "
|
||||
"Models found: {model_list}}}")\
|
||||
.format(tag_name=tag_name, tag_value=tag_value,
|
||||
model_list=model_list)
|
||||
no_model_found = ("No Model found with {{tag_name: {tag_name} ,"
|
||||
"tag_value: {tag_value}.}}")\
|
||||
.format(tag_name=tag_name, tag_value=tag_value)
|
||||
|
||||
# define the error messages
|
||||
too_many_model_message = ("Found more than one latest model. "
|
||||
f"Models found: {model_list}. "
|
||||
f"{tag_ext}")
|
||||
|
||||
no_model_found_message = (f"No Model found with name: {model_name}. "
|
||||
f"{tag_ext}")
|
||||
|
||||
if len(model_list) > 1:
|
||||
raise ValueError(should_not_happen)
|
||||
raise ValueError(too_many_model_message)
|
||||
if len(model_list) == 1:
|
||||
return model_list[0]
|
||||
else:
|
||||
print(no_model_found)
|
||||
print(no_model_found_message)
|
||||
return None
|
||||
except Exception:
|
||||
raise
|
||||
|
|
|
@ -3,7 +3,7 @@ import sys
|
|||
import os
|
||||
from azureml.core import Run, Experiment, Workspace
|
||||
from ml_service.util.env_variables import Env
|
||||
from diabetes_regression.util.model_helper import get_model_by_tag
|
||||
from diabetes_regression.util.model_helper import get_latest_model
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -53,7 +53,7 @@ def main():
|
|||
|
||||
try:
|
||||
tag_name = 'BuildId'
|
||||
model = get_model_by_tag(
|
||||
model = get_latest_model(
|
||||
model_name, tag_name, build_id, exp.workspace)
|
||||
if (model is not None):
|
||||
print("Model was registered for this build.")
|
||||
|
|
Загрузка…
Ссылка в новой задаче