This commit is contained in:
Daniel Heinze 2020-03-15 18:52:00 +01:00 коммит произвёл GitHub
Родитель 58876338ed
Коммит d531b2e3eb
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 37 добавлений и 30 удалений

Просмотреть файл

@ -26,7 +26,7 @@ POSSIBILITY OF SUCH DAMAGE.
from azureml.core import Run
import argparse
import traceback
from util.model_helper import get_model_by_tag
from util.model_helper import get_latest_model
run = Run.get_context()
@ -45,7 +45,7 @@ run = Run.get_context()
# sources_dir = 'diabetes_regression'
# path_to_util = os.path.join(".", sources_dir, "util")
# sys.path.append(os.path.abspath(path_to_util)) # NOQA: E402
# from model_helper import get_model_by_tag
# from model_helper import get_latest_model
# workspace_name = os.environ.get("WORKSPACE_NAME")
# experiment_name = os.environ.get("EXPERIMENT_NAME")
# resource_group = os.environ.get("RESOURCE_GROUP")
@ -108,7 +108,7 @@ try:
firstRegistration = False
tag_name = 'experiment_name'
model = get_model_by_tag(
model = get_latest_model(
model_name, tag_name, exp.name, ws)
if (model is not None):

Просмотреть файл

@ -22,20 +22,20 @@ def get_current_workspace() -> Workspace:
return experiment.workspace
def get_model_by_tag(
def get_latest_model(
model_name: str,
tag_name: str,
tag_value: str,
tag_name: str = None,
tag_value: str = None,
aml_workspace: Workspace = None
) -> AMLModel:
"""
Retrieves and returns the latest model from the workspace
by its name and tag.
by its name and (optional) tag.
Parameters:
aml_workspace (Workspace): aml.core Workspace that the model lives.
model_name (str): name of the model we are looking for
tag (str): the tag value the model was registered under.
(optional) tag (str): the tag value & name the model was registered under.
Return:
A single aml model from the workspace that matches the name and tag.
@ -44,37 +44,44 @@ def get_model_by_tag(
# Validate params. cannot be None.
if model_name is None:
raise ValueError("model_name[:str] is required")
if tag_name is None:
raise ValueError("tag_name[:str] is required")
if tag_value is None:
raise ValueError("tag[:str] is required")
if aml_workspace is None:
print("No workspace defined - using current experiment workspace.")
aml_workspace = get_current_workspace()
# get model by tag.
model_list = None
tag_ext = ""
# Get lastest model
# True: by name and tags
if tag_name is not None and tag_value is not None:
model_list = AMLModel.list(
aml_workspace, name=model_name,
tags=[[tag_name, tag_value]], latest=True
)
tag_ext = f"tag_name: {tag_name}, tag_value: {tag_value}."
# False: Only by name
else:
model_list = AMLModel.list(
aml_workspace, name=model_name, latest=True)
# latest should only return 1 model, but if it does,
# then maybe sdk or source code changed.
should_not_happen = ("Found more than one model "
"for the latest with {{tag_name: {tag_name},"
"tag_value: {tag_value}. "
"Models found: {model_list}}}")\
.format(tag_name=tag_name, tag_value=tag_value,
model_list=model_list)
no_model_found = ("No Model found with {{tag_name: {tag_name} ,"
"tag_value: {tag_value}.}}")\
.format(tag_name=tag_name, tag_value=tag_value)
# define the error messages
too_many_model_message = ("Found more than one latest model. "
f"Models found: {model_list}. "
f"{tag_ext}")
no_model_found_message = (f"No Model found with name: {model_name}. "
f"{tag_ext}")
if len(model_list) > 1:
raise ValueError(should_not_happen)
raise ValueError(too_many_model_message)
if len(model_list) == 1:
return model_list[0]
else:
print(no_model_found)
print(no_model_found_message)
return None
except Exception:
raise

Просмотреть файл

@ -3,7 +3,7 @@ import sys
import os
from azureml.core import Run, Experiment, Workspace
from ml_service.util.env_variables import Env
from diabetes_regression.util.model_helper import get_model_by_tag
from diabetes_regression.util.model_helper import get_latest_model
def main():
@ -53,7 +53,7 @@ def main():
try:
tag_name = 'BuildId'
model = get_model_by_tag(
model = get_latest_model(
model_name, tag_name, build_id, exp.workspace)
if (model is not None):
print("Model was registered for this build.")