ENH: Add support for model and metadata serialization (#252)

Added tools to write config objects to YAML and read them back in.
This commit is contained in:
Anton Schwaighofer 2022-04-09 12:21:10 +02:00 committed by GitHub
Parent cea7177e16
Commit d2313f5336
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
13 changed files: 1000 additions and 7 deletions

View file

@@ -7,8 +7,7 @@ outputs/
__pycache__/
.pytest_cache
.mypy_cache
logs
outputs
logs/
config.json
*.egg-info
# Temporary files generated from conda merging

View file

@@ -11,6 +11,7 @@ from health_azure.utils import (RUN_CONTEXT, aggregate_hyperdrive_metrics, creat
fetch_run, get_most_recent_run, is_running_in_azure_ml,
set_environment_variables_for_multi_node, split_recovery_id, torch_barrier,
upload_to_datastore)
from health_azure.traverse import object_to_yaml, write_yaml_to_object
__all__ = [
"AzureRunInfo",
@@ -33,5 +34,7 @@ __all__ = [
"torch_barrier",
"upload_to_datastore",
"create_crossval_hyperdrive_config",
"aggregate_hyperdrive_metrics"
"aggregate_hyperdrive_metrics",
"object_to_yaml",
"write_yaml_to_object"
]
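With these two additions to __all__, the traversal helpers become part of the package's public surface. A quick sanity check of the new exports, assuming the hi-ml-azure package is installed:

# Both names resolve to health_azure.traverse through the package root.
from health_azure import object_to_yaml, write_yaml_to_object

assert object_to_yaml.__module__ == "health_azure.traverse"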

View file

@@ -0,0 +1,267 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import enum
import logging
from io import StringIO
from typing import Any, Dict, Iterable, Union, List, Optional
import param
from ruamel import yaml
def is_basic_type(o: Any) -> bool:
"""
Returns True if the given object is an instance of a basic datatype: string, integer, or float.
"""
return isinstance(o, (str, int, float))
def is_enum(o: Any) -> bool:
"""
Returns True if the given object is an instance of an enum.Enum subclass.
:param o: The object to inspect.
:return: True if the object is an enum, False otherwise.
"""
return isinstance(o, enum.Enum)
def get_all_writable_attributes(o: Any) -> Dict[str, Any]:
"""
Returns all writable attributes of an object, via the "vars" method. For objects that derive
from param.Parameterized, it returns all params that are not constant and not readonly.
:param o: The object to inspect.
:return: A dictionary mapping from attribute name to its value.
"""
def _is_private(s: str) -> bool:
return s.startswith("_")
result = {}
if isinstance(o, param.Parameterized):
for param_name, p in o.params().items():
if _is_private(param_name):
logging.debug(f"get_all_writable_attributes: Skipping private field {param_name}")
elif p.constant:
logging.debug(f"get_all_writable_attributes: Skipping constant field {param_name}")
elif p.readonly:
logging.debug(f"get_all_writable_attributes: Skipping readonly field {param_name}")
else:
result[param_name] = getattr(o, param_name)
return result
try:
for name, value in vars(o).items():
if _is_private(name):
logging.debug(f"get_all_writable_attributes: Skipping private field {name}")
else:
result[name] = value
return result
except TypeError:
raise ValueError("This function can only be used on objects that support the 'vars' operation")
def all_basic_types(o: Iterable) -> bool:
"""Checks if all entries of the iterable are of a basic datatype (int, str, float).
:param o: The iterable that should be checked.
:return: True if all entries of the iterable are of a basic datatype.
"""
for item in o:
if not is_basic_type(item):
return False
return True
def _object_to_dict(o: Any) -> Union[None, int, float, str, List, Dict]:
"""
Converts an object to a dictionary mapping from attribute name to value. If an attribute is not a simple
datatype, its value is converted to a dictionary recursively. Lists and dictionaries are returned as-is.
:param o: The object to inspect.
:return: Returns the argument if the object is a basic datatype, otherwise a dictionary mapping from attribute
name to value.
"""
if is_basic_type(o):
return o
if isinstance(o, enum.Enum):
return o.name
if isinstance(o, list):
if not all_basic_types(o):
raise ValueError(f"Lists are only allowed to contain basic types (int, float, str), but got: {o}")
return o
if isinstance(o, dict):
if not all_basic_types(o.keys()):
raise ValueError(f"Dictionaries can only contain basic types (int, float, str) as keys, but got: {o}")
if not all_basic_types(o.values()):
raise ValueError(f"Dictionaries can only contain basic types (int, float, str) as values, but got: {o}")
return o
if o is None:
return o
try:
fields = get_all_writable_attributes(o)
return {field: _object_to_dict(value) for field, value in fields.items()}
except ValueError as ex:
raise ValueError(f"Unable to traverse object {o}: {ex}")
def object_to_dict(o: Any) -> Dict[str, Any]:
"""
Converts an object to a dictionary mapping from attribute name to value. If an attribute is not a simple
datatype, its value is converted to a dictionary recursively.
This function only works on objects that are not of a basic datatype (i.e., class instances). Private fields
(names starting with underscores) and ones that are constant or readonly are omitted. For attributes that are
Enums, the case name is returned as a string.
:param o: The object to inspect.
:return: A dictionary mapping from attribute name to value, where values may recursively be dictionaries again.
:raises ValueError: If the argument is a basic datatype (int, str, float)
"""
if is_basic_type(o):
raise ValueError("This function can only be used on objects that are basic datatypes.")
fields = get_all_writable_attributes(o)
result = {}
for field, value in fields.items():
logging.debug(f"object_to_dict: Processing {field}")
result[field] = _object_to_dict(value)
return result
def object_to_yaml(o: Any) -> str:
"""
Converts an object to a YAML string representation. This is done by recursively traversing all attributes and
writing them out to YAML if they are basic datatypes.
:param o: The object to inspect.
:return: A string in YAML format.
"""
return yaml.safe_dump(object_to_dict(o), default_flow_style=False) # type: ignore
def yaml_to_dict(s: str) -> Dict[str, Any]:
"""
Interprets a string as YAML and returns the contents as a dictionary.
:param s: The YAML string to parse.
:return: A dictionary where the keys are the YAML field names, and values are either the YAML leaf node values,
or dictionaries again.
"""
stream = StringIO(s)
return yaml.safe_load(stream=stream)
def _write_dict_to_object(o: Any, d: Dict[str, Any], traversed_fields: Optional[List] = None) -> List[str]:
"""
Writes a dictionary of values into an object, assuming that the attributes of the object and the dictionary keys
are in sync. For example, to write a dictionary {"foo": 1, "bar": "baz"} to an object, the object needs to have
attributes "foo" and "bar".
:param o: The object to write to.
:param d: A dictionary mapping from attribute names to values or dictionaries recursively.
:return: A list of error messages collected.
"""
issues: List[str] = []
traversed = traversed_fields or []
def report_issue(name: str, message: str) -> None:
full_field_name = ".".join(traversed + [name])
issues.append(f"Attribute {full_field_name}: {message}")
def try_set_field(name: str, value_to_write: Any) -> None:
try:
setattr(o, name, value_to_write)
except Exception as ex:
report_issue(name, f"Unable to set value {value_to_write}: {ex}")
existing_attrs = get_all_writable_attributes(o)
for name, value in existing_attrs.items():
if name in d:
value_to_write = d[name]
t_value = type(value)
t_value_to_write = type(value_to_write)
if is_basic_type(value) and is_basic_type(value_to_write):
if t_value != t_value_to_write:
report_issue(
name,
f"Skipped. Current value has type {t_value.__name__}, but trying to "
f"write {t_value_to_write.__name__}",
)
else:
try_set_field(name, value_to_write)
elif isinstance(value, enum.Enum):
if isinstance(value_to_write, str):
try:
enum_case = getattr(t_value, value_to_write)
except Exception:
report_issue(name, f"Skipped. Enum type {t_value.__name__} has no case {value_to_write}")
else:
try_set_field(name, enum_case)
else:
report_issue(
name,
"Skipped. This is an Enum field. Can only write string values to that field "
f"(case name), but got value of type {t_value_to_write.__name__}",
)
elif value is None or value_to_write is None:
# We can't do much type checking if we get Nones. This is a potential source of errors.
try_set_field(name, value_to_write)
elif isinstance(value, List) and isinstance(value_to_write, List):
try_set_field(name, value_to_write)
elif isinstance(value, Dict) and isinstance(value_to_write, Dict):
try_set_field(name, value_to_write)
elif not is_basic_type(value) and isinstance(value_to_write, Dict):
# For anything that is not a basic datatype, we expect that we get a dictionary of fields
# recursively.
new_issues = _write_dict_to_object(
getattr(o, name), value_to_write, traversed_fields=traversed + [name]
)
issues.extend(new_issues)
else:
report_issue(
name,
f"Skipped. Current value has type {t_value.__name__}, but trying to "
f"write {t_value_to_write.__name__}",
)
else:
report_issue(name, "Present in the object, but missing in the dictionary.")
return issues
def write_dict_to_object(o: Any, d: Dict[str, Any], strict: bool = True) -> None:
"""
Writes a dictionary of values into an object, assuming that the attributes of the object and the dictionary keys
are in sync. For example, to write a dictionary {"foo": 1, "bar": "baz"} to an object, the object needs to have
attributes "foo" and "bar".
:param strict: If True, any mismatch of field names will raise a ValueError. If False, only a warning will be
printed. Note that the object may have been modified even if an error is raised.
:param o: The object to write to.
:param d: A dictionary mapping from attribute names to values or dictionaries recursively.
"""
issues = _write_dict_to_object(o, d)
if len(issues) == 0:
return
message = f"Unable to complete writing to the object: Found {len(issues)} problems. Please inspect console log."
for issue in issues:
logging.warning(issue)
if strict:
raise ValueError(message)
else:
logging.warning(message)
def write_yaml_to_object(o: Any, yaml_string: str, strict: bool = False) -> None:
"""
Writes a serialized object in YAML format back into an object, assuming that the attributes of the object and
the YAML field names are in sync.
:param strict: If True, any mismatch of field names will raise a ValueError. If False, only a warning will be
printed. Note that the object may have been modified even if an error is raised.
:param o: The object to write to.
:param yaml_string: A YAML formatted string with attribute names and values.
"""
d = yaml_to_dict(yaml_string)
write_dict_to_object(o, d, strict=strict)
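The pieces above compose into a simple round trip. A minimal sketch, using an illustrative dataclass that is not part of this commit:

from dataclasses import dataclass

from health_azure import object_to_yaml, write_yaml_to_object

@dataclass
class SketchConfig:
    learning_rate: float = 1e-3
    optimizer: str = "Adam"

config = SketchConfig()
yaml_text = object_to_yaml(config)
# yaml.safe_dump sorts the keys, so yaml_text reads:
#   learning_rate: 0.001
#   optimizer: Adam

# Write the YAML back into a fresh object. With strict=True, any field-name or
# type mismatch raises a ValueError instead of only logging warnings.
restored = SketchConfig(learning_rate=0.0, optimizer="SGD")
write_yaml_to_object(restored, yaml_text, strict=True)
assert restored == config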

View file

@@ -72,7 +72,7 @@ ENVIRONMENT_VERSION = "1"
FINAL_MODEL_FOLDER = "final_model"
MODEL_ID_KEY_NAME = "model_id"
PYTHON_ENVIRONMENT_NAME = "python_environment_name"
RUN_CONTEXT = Run.get_context()
RUN_CONTEXT: Run = Run.get_context()
PARENT_RUN_CONTEXT = getattr(RUN_CONTEXT, "parent", None)
WORKSPACE_CONFIG_JSON = "config.json"

View file

@@ -0,0 +1,389 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import enum
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import param
import pytest
from _pytest.logging import LogCaptureFixture
from health_azure.traverse import (
_object_to_dict,
get_all_writable_attributes,
is_basic_type,
object_to_dict,
object_to_yaml,
yaml_to_dict,
write_dict_to_object,
write_yaml_to_object,
_write_dict_to_object,
)
@dataclass
class OptimizerConfig:
learning_rate: float = 1e-3
optimizer: Optional[str] = "Adam"
@dataclass
class TransformConfig:
blur_sigma: float = 0.1
blur_p: float = 0.2
@dataclass
class FullConfig:
optimizer: OptimizerConfig = OptimizerConfig()
transforms: TransformConfig = TransformConfig()
@dataclass
class TripleNestedConfig:
float1: float = 1.0
int1: int = 1
nested: FullConfig = FullConfig()
@dataclass
class ConfigWithList:
list_field: Union[List, Dict] = field(default_factory=list)
class ParamsConfig(param.Parameterized):
p1: int = param.Integer(default=1)
p2: str = param.String(default="foo")
p3: float = param.Number(default=2.0)
class ParamsConfigWithReadonly(param.Parameterized):
p1: int = param.Integer(default=1, readonly=True)
p2: str = param.String(default="foo", constant=True)
p3: float = param.Number(default=2.0)
_p4: float = param.Number(default=4.0)
class MyEnum(enum.Enum):
foo = "foo_value"
bar = "bar_value"
@dataclass
class ConfigWithEnum:
f1: MyEnum
def test_traverse1() -> None:
config = TransformConfig()
d = object_to_dict(config)
assert d == {"blur_sigma": 0.1, "blur_p": 0.2}
def test_traverse2() -> None:
config = FullConfig()
d = object_to_dict(config)
assert d == {
"optimizer": {"learning_rate": 1e-3, "optimizer": "Adam"},
"transforms": {"blur_sigma": 0.1, "blur_p": 0.2},
}
def test_traverse_params() -> None:
config = ParamsConfig()
d = object_to_dict(config)
assert d == {"p1": 1, "p2": "foo", "p3": 2.0}
def test_traverse_params_readonly() -> None:
"""
Private, constant and readonly fields of a Params object should be skipped.
"""
config = ParamsConfigWithReadonly()
d = get_all_writable_attributes(config)
assert d == {"p3": 2.0}
def test_traverse_list() -> None:
"""Lists of basic datatypes should be serialized"""
list_config = ConfigWithList(list_field=[1, 2.5, "foo"])
assert object_to_yaml(list_config) == """list_field:
- 1
- 2.5
- foo
"""
list_config.list_field = [1, dict()]
with pytest.raises(ValueError) as ex:
object_to_dict(list_config)
assert "only allowed to contain basic types" in str(ex)
def test_is_basic() -> None:
assert is_basic_type(1)
assert is_basic_type(1.0)
assert is_basic_type("foo")
assert not is_basic_type(None)
assert not is_basic_type(dict())
assert not is_basic_type([])
def test_traverse_dict() -> None:
"""Fields with dictionaries with basic datatypes should be serialized"""
list_config = ConfigWithList(list_field={1: "foo", "bar": 2.0})
assert object_to_yaml(list_config) == """list_field:
1: foo
bar: 2.0
"""
# Invalid dictionaries contain non-basic types either as keys or as values
for invalid in [{"foo": dict()}, {(1, 2): "foo"}]:
list_config.list_field = invalid # type: ignore
with pytest.raises(ValueError) as ex:
object_to_dict(list_config)
assert "Dictionaries can only contain basic types" in str(ex)
def test_traverse_enum() -> None:
"""
Enum objects have no non-private fields and should return an empty attribute dictionary.
"""
config = MyEnum.foo
d = get_all_writable_attributes(config)
assert len(d) == 0
# Enum objects should be printed out by their case name
d = _object_to_dict(config)
assert d == "foo"
def test_traverse_none() -> None:
"""
Attributes that are None should be preserved as such in YAML.
"""
config = OptimizerConfig()
config.optimizer = None
assert _object_to_dict(None) is None
assert _object_to_dict(config) == {"learning_rate": 1e-3, "optimizer": None}
assert object_to_yaml(config) == """learning_rate: 0.001
optimizer: null
"""
def test_to_yaml_roundtrip() -> None:
config = FullConfig()
yaml = object_to_yaml(config)
print("\n" + yaml)
dict = yaml_to_dict(yaml)
assert dict == object_to_dict(config)
def test_params_roundtrip() -> None:
config = ParamsConfig()
yaml = object_to_yaml(config)
print("\n" + yaml)
dict = yaml_to_dict(yaml)
assert dict == object_to_dict(config)
@pytest.mark.parametrize("my_list", [[1, "foo"], [], {1: 2.0}, {}])
def test_list_dict_roundtrip(my_list: Any) -> None:
"""Test if configs with lists and dicts can be converted to YAML and back"""
config = ConfigWithList(list_field=my_list)
yaml = object_to_yaml(config)
print("\n" + yaml)
# Initialize the new object with a non-empty container of the same type, to check that empty lists
# and dicts also get written correctly
config2 = ConfigWithList(list_field=["not_in_args"] if isinstance(my_list, list) else {"not": "in_args"})
write_yaml_to_object(config2, yaml)
assert config2.list_field == my_list
def test_write_flat() -> None:
obj = OptimizerConfig()
learning_rate = 3.0
optimizer = "Foo"
d = {"learning_rate": learning_rate, "optimizer": optimizer}
write_dict_to_object(obj, d)
assert obj.learning_rate == learning_rate
assert obj.optimizer == optimizer
def test_write_nested() -> None:
obj = TripleNestedConfig()
yaml = object_to_yaml(obj)
print("\n" + yaml)
float1 = 2.0
int1 = 2
blur_p = 7.0
blur_sigma = 8.0
learning_rate = 3.0
optimizer = "Foo"
d1 = {"transforms": {"blur_sigma": blur_sigma, "blur_p": blur_p},
"optimizer": {"learning_rate": learning_rate, "optimizer": optimizer}}
d = {"float1": float1, "int1": int1, "nested": d1}
write_dict_to_object(obj, d)
assert obj.float1 == float1
assert obj.int1 == int1
assert obj.nested.optimizer.optimizer == optimizer
assert obj.nested.optimizer.learning_rate == learning_rate
assert obj.nested.transforms.blur_p == blur_p
assert obj.nested.transforms.blur_sigma == blur_sigma
from_obj = object_to_dict(obj)
assert from_obj == d
yaml = object_to_yaml(obj)
print("\n" + yaml)
obj2 = TripleNestedConfig()
write_yaml_to_object(obj2, yaml_string=yaml)
print("\n" + repr(obj2))
from_yaml_obj = object_to_dict(obj2)
assert from_yaml_obj == d
def test_yaml_and_write_roundtrip() -> None:
obj = ParamsConfig()
obj.p1 = 2
obj.p2 = "nothing"
obj.p3 = 3.14
yaml = object_to_yaml(obj)
print("\n" + yaml)
obj2 = ParamsConfig()
write_yaml_to_object(obj2, yaml_string=yaml)
assert obj2.p1 == obj.p1
assert obj2.p2 == obj.p2
assert obj2.p3 == obj.p3
def test_to_yaml_datatypes() -> None:
"""
Ensure that string fields that look like numbers are treated correctly.
"""
config = OptimizerConfig()
config.optimizer = "2"
yaml = object_to_yaml(config)
print("\n" + yaml)
dict = yaml_to_dict(yaml)
assert dict == object_to_dict(config)
def test_object_to_yaml_floats() -> None:
"""
Check that floating point numbers that could be mistaken for integers are rendered unambiguously in YAML.
"""
config = TransformConfig()
config.blur_p = 1.0
yaml = object_to_yaml(config)
assert yaml == """blur_p: 1.0
blur_sigma: 0.1
"""
def test_write_dict_errors1(caplog: LogCaptureFixture) -> None:
"""
Check type mismatches between the object and the YAML contents.
"""
# First test the private writer method
config = TransformConfig()
dict = {"blur_p": 1, "blur_sigma": 0.1}
errors = _write_dict_to_object(config, dict)
assert len(errors) == 1
assert "Attribute blur_p" in errors[0]
assert "Skipped" in errors[0]
assert "Current value has type float" in errors[0]
assert "trying to write int" in errors[0]
# The same error message should be raised when calling the full method, with either strict or non-strict
config = TransformConfig()
with caplog.at_level(level=logging.WARNING):
with pytest.raises(ValueError) as ex:
write_dict_to_object(config, dict, strict=True)
assert "Found 1 problems" in str(ex)
assert errors[0] in caplog.text
config = TransformConfig()
with caplog.at_level(level=logging.WARNING):
write_dict_to_object(config, dict, strict=False)
assert errors[0] in caplog.text
def test_write_dict_errors2(caplog: LogCaptureFixture) -> None:
"""
Check type mismatches between the object and the YAML contents, and that nested field names are reported correctly.
"""
caplog.set_level(logging.WARNING)
config = FullConfig()
dict = object_to_dict(config)
dict["transforms"]["blur_p"] = "foo"
dict["optimizer"]["learning_rate"] = "bar"
write_dict_to_object(config, dict, strict=False)
assert "Found 2 problems" in caplog.text
assert "Attribute transforms.blur_p" in caplog.text
assert "Attribute optimizer.learning_rate" in caplog.text
def test_write_dict_errors3() -> None:
"""
Check handling of cases where the object has more fields than are present in the dictionary.
"""
config = TransformConfig()
dict = {"blur_p": 1.0}
issues = _write_dict_to_object(config, dict)
assert len(issues) == 1
assert "Present in the object, but missing in the dictionary" in issues[0]
assert "Attribute blur_sigma" in issues[0]
def test_write_enums() -> None:
"""
Test handling of fields that are enum types
"""
config = ConfigWithEnum(f1=MyEnum.bar)
d = object_to_dict(config)
assert d == {"f1": "bar"}
# Now change f1 to another enum value, write back the previous dictionary, and check that f1 is back at the
# original value
config.f1 = MyEnum.foo
write_dict_to_object(config, d)
assert config.f1 == MyEnum.bar
def test_write_enums_errors() -> None:
"""
Error handling for enum fields
"""
config = ConfigWithEnum(f1=MyEnum.bar)
# Trying to write a floating point number, but we expect a string case name
issues = _write_dict_to_object(config, {"f1": 1.0})
assert len(issues) == 1
assert "Enum field" in issues[0]
assert "got value of type float" in issues[0]
# Referencing an enum case that does not exist
issues = _write_dict_to_object(config, {"f1": "no_such_case"})
assert len(issues) == 1
assert "Enum type MyEnum has no case no_such_case" in issues[0]
# This should work fine
issues = _write_dict_to_object(config, {"f1": MyEnum.bar.name})
assert len(issues) == 0
def test_write_null() -> None:
"""
Round-trip test for writing fields that are None.
"""
config = OptimizerConfig()
config.optimizer = None
dict = object_to_dict(config)
yaml = object_to_yaml(config)
config.optimizer = "foo"
issues = _write_dict_to_object(config, dict)
assert len(issues) == 0
assert config.optimizer is None
config.optimizer = "foo"
write_yaml_to_object(config, yaml)
assert config.optimizer is None

View file

@@ -9,7 +9,7 @@ from pathlib import Path
from typing import Any, List, Optional, Tuple, TypeVar
from pytorch_lightning import Callback, Trainer, seed_everything
from pytorch_lightning.callbacks import GPUStatsMonitor, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.callbacks import GPUStatsMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin
@@ -134,6 +134,8 @@ def create_lightning_trainer(container: LightningContainer,
write_to_logging_info=True,
print_timestamp=False))
else:
# Use a local import here to be able to support older PL versions
from pytorch_lightning.callbacks import TQDMProgressBar
callbacks.append(TQDMProgressBar(refresh_rate=progress_bar_refresh_rate))
# Read out additional model-specific args here.
# We probably want to keep essential ones like numgpu and logging.

View file

@@ -5,6 +5,7 @@
from health_ml.utils.logging import AzureMLLogger, AzureMLProgressBar, log_learning_rate, log_on_epoch
from health_ml.utils.diagnostics import BatchTimeCallback
from health_ml.utils.common_utils import set_model_to_eval_mode
__all__ = [
"AzureMLLogger",
@@ -12,4 +13,5 @@ __all__ = [
"BatchTimeCallback",
"log_learning_rate",
"log_on_epoch",
"set_model_to_eval_mode"
]
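The newly exported set_model_to_eval_mode is a context manager that puts a model into eval mode and restores its previous mode on exit, as exercised by the test near the end of this commit. A small usage sketch:

import torch

from health_ml.utils import set_model_to_eval_mode

model = torch.nn.Linear(4, 2)
model.train()
with set_model_to_eval_mode(model):
    assert not model.training  # dropout/batchnorm layers now behave deterministically
assert model.training  # the previous training mode is restored on exit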

View file

@@ -1,3 +1,8 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import logging
import os
import sys
@@ -16,7 +21,6 @@ from health_azure.paths import ENVIRONMENT_YAML_FILE_NAME, git_repo_root_folder,
from health_azure.utils import PathOrString, is_conda_file_with_pip_include
MAX_PATH_LENGTH = 260
# convert string to None if an empty string or whitespace is provided

View file

@@ -1,8 +1,8 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import base64
import itertools
import mimetypes

View file

@@ -0,0 +1,173 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import pickle
from io import BytesIO
from typing import Any, Optional, Callable, Dict, Union
import torch
from health_azure import RUN_CONTEXT, object_to_yaml, is_running_in_azure_ml
def _dump_to_stream(o: Any) -> BytesIO:
"""
Writes the given object to a byte stream in pickle format, and returns the stream.
:param o: The object to pickle.
:return: A byte stream, with the position set to 0.
"""
stream = BytesIO()
pickle.dump(o, file=stream)
stream.seek(0)
return stream
class ModelInfo:
"""Stores a model, its example input, and metadata that describes how the model was trained.
"""
MODEL = "model"
MODEL_EXAMPLE_INPUT = "model_example_input"
MODEL_CONFIG_YAML = "model_config_yaml"
GIT_REPOSITORY = "git_repository"
GIT_COMMIT_HASH = "git_commit_hash"
DATASET_NAME = "dataset_name"
AZURE_ML_WORKSPACE = "azure_ml_workspace"
AZURE_ML_RUN_ID = "azure_ml_run_id"
TEXT_TOKENIZER = "text_tokenizer"
IMAGE_PRE_PROCESSING = "image_pre_processing"
IMAGE_DIMENSIONS = "image_dimensions"
OTHER_INFO = "other_info"
OTHER_DESCRIPTION = "other_description"
def __init__(self,
model: Optional[Union[torch.nn.Module, torch.jit.ScriptModule]] = None,
model_example_input: Optional[torch.Tensor] = None,
model_config: Any = None,
git_repository: str = "",
git_commit_hash: str = "",
dataset_name: str = "",
azure_ml_workspace: str = "",
azure_ml_run_id: str = "",
text_tokenizer: Any = None,
image_pre_processing: Optional[Callable] = None,
image_dimensions: str = "",
other_info: Any = None,
other_description: str = "",
):
"""
:param model: The model that should be serialized, or the deserialized model, defaults to None
:param model_example_input: A tensor that can be input to the forward pass of the model, defaults to None
:param model_config: The configuration object that was used to start the training run, defaults to None
:param git_repository: The name of the git repository that contains the training codebase, defaults to ""
:param git_commit_hash: The git commit hash that was used to run the training, defaults to ""
:param dataset_name: The name of the dataset that was used to train the model, defaults to ""
:param azure_ml_workspace: The name of the AzureML workspace that contains the training run, defaults to ""
:param azure_ml_run_id: The AzureML run that did the training, defaults to ""
:param text_tokenizer: A text tokenizer object to pre-process the model input (default: None). The object given
here will be pickled before it is passed to ``torch.save``.
:param image_pre_processing: An object that describes the processing for the image before it is input to the
model, defaults to None
:param image_dimensions: The size of the pre-processed image that is accepted by the model, defaults to ""
:param other_info: An arbitrary object that will also be written to the checkpoint. For example, this can be a
binary stream. The object provided here will be pickled before it is passed to ``torch.save``.
:param other_description: A human-readable description of what the ``other_info`` field contains.
"""
self.model = model
self.model_example_input = model_example_input
self.model_config = model_config
self.git_repository = git_repository
self.git_commit_hash = git_commit_hash
self.dataset_name = dataset_name
self.azure_ml_workspace = azure_ml_workspace
self.azure_ml_run_id = azure_ml_run_id
self.text_tokenizer = text_tokenizer
self.image_pre_processing = image_pre_processing
self.image_dimensions = image_dimensions
self.other_info = other_info
self.other_description = other_description
def get_metadata_from_azureml(self) -> None:
"""Reads information about the git repository and AzureML-related info from the AzureML run context.
If any of those information are already stored in the object, those have higher priority.
"""
if not is_running_in_azure_ml():
return
if not self.azure_ml_workspace:
self.azure_ml_workspace = RUN_CONTEXT.experiment.workspace.name
if not self.azure_ml_run_id:
self.azure_ml_run_id = RUN_CONTEXT.id
properties: Dict = RUN_CONTEXT.get_properties()
if not self.git_repository:
self.git_repository = properties.get("azureml.git.repository_uri", "")
if not self.git_commit_hash:
self.git_commit_hash = properties.get("azureml.git.commit", "")
def state_dict(self, strict: bool = True) -> Dict[str, Any]:
"""Creates a dictionary representation of the current object.
:param strict: The setting for the 'strict' flag in the call to torch.jit.trace.
"""
def bytes_or_none(o: Any) -> Optional[bytes]:
return _dump_to_stream(o).getvalue() if o is not None else None
if self.model is None or self.model_example_input is None:
raise ValueError("To generate a state dict, the model and model_example_input must be present.")
try:
traced_model = torch.jit.trace(self.model, self.model_example_input, strict=strict)
except Exception as ex:
raise ValueError(f"Unable to convert the model to torchscript: {ex}")
jit_stream = BytesIO()
torch.jit.save(traced_model, jit_stream)
try:
config_yaml = object_to_yaml(self.model_config) if self.model_config is not None else None
except Exception as ex:
raise ValueError(f"Unable to convert the model configuration to YAML: {ex}")
return {
ModelInfo.MODEL: jit_stream.getvalue(),
ModelInfo.MODEL_EXAMPLE_INPUT: self.model_example_input,
ModelInfo.MODEL_CONFIG_YAML: config_yaml,
ModelInfo.GIT_REPOSITORY: self.git_repository,
ModelInfo.GIT_COMMIT_HASH: self.git_commit_hash,
ModelInfo.DATASET_NAME: self.dataset_name,
ModelInfo.AZURE_ML_WORKSPACE: self.azure_ml_workspace,
ModelInfo.AZURE_ML_RUN_ID: self.azure_ml_run_id,
ModelInfo.TEXT_TOKENIZER: bytes_or_none(self.text_tokenizer),
ModelInfo.IMAGE_PRE_PROCESSING: bytes_or_none(self.image_pre_processing),
ModelInfo.IMAGE_DIMENSIONS: self.image_dimensions,
ModelInfo.OTHER_INFO: bytes_or_none(self.other_info),
ModelInfo.OTHER_DESCRIPTION: self.other_description
}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Loads a dictionary representation into the current object, overwriting all matching fields.
:param state_dict: The dictionary to load from.
"""
def unpickle_from_bytes(field: str) -> Any:
if field not in state_dict:
raise KeyError(f"State_dict does not contain a field '{field}'")
b = state_dict[field]
if b is None:
return None
stream = BytesIO(b)
try:
o = pickle.load(stream)
except Exception as ex:
raise ValueError(f"Failure when unpickling field '{field}': {ex}")
return o
self.model = torch.jit.load(BytesIO(state_dict[ModelInfo.MODEL]))
self.model_example_input = state_dict[ModelInfo.MODEL_EXAMPLE_INPUT]
self.model_config = state_dict[ModelInfo.MODEL_CONFIG_YAML]
self.git_repository = state_dict[ModelInfo.GIT_REPOSITORY]
self.git_commit_hash = state_dict[ModelInfo.GIT_COMMIT_HASH]
self.dataset_name = state_dict[ModelInfo.DATASET_NAME]
self.azure_ml_workspace = state_dict[ModelInfo.AZURE_ML_WORKSPACE]
self.azure_ml_run_id = state_dict[ModelInfo.AZURE_ML_RUN_ID]
self.text_tokenizer = unpickle_from_bytes(ModelInfo.TEXT_TOKENIZER)
self.image_pre_processing = unpickle_from_bytes(ModelInfo.IMAGE_PRE_PROCESSING)
self.image_dimensions = state_dict[ModelInfo.IMAGE_DIMENSIONS]
self.other_info = unpickle_from_bytes(ModelInfo.OTHER_INFO)
self.other_description = state_dict[ModelInfo.OTHER_DESCRIPTION]
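A sketch of the intended checkpointing workflow; the model, file name, and metadata values below are illustrative. Note that state_dict requires both model and model_example_input so that torch.jit.trace can run:

import torch

from health_ml.utils.serialization import ModelInfo

model = torch.nn.Linear(3, 1)     # illustrative model
example = torch.randn((2, 3))     # example input, required for torch.jit.trace

info = ModelInfo(model=model,
                 model_example_input=example,
                 git_commit_hash="abc123",
                 dataset_name="my_dataset")
info.get_metadata_from_azureml()  # fills in run/workspace metadata; a no-op outside AzureML
torch.save(info.state_dict(), "model_and_metadata.pt")

# Later, possibly on another machine: restore everything from the checkpoint.
info2 = ModelInfo()
info2.load_state_dict(torch.load("model_and_metadata.pt"))
assert isinstance(info2.model, torch.jit.ScriptModule)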

View file

@@ -0,0 +1,131 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Optional
from unittest import mock
import torch
from torchvision.transforms import Compose, Resize, CenterCrop
from azureml.core import Run
from health_azure import object_to_yaml, create_aml_run_object
from health_azure.utils import is_running_in_azure_ml
from health_ml.utils.serialization import ModelInfo
from testazure.utils_testazure import DEFAULT_WORKSPACE
class MyTestModule(torch.nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.max(input)
@dataclass
class MyModelConfig:
foo: str
bar: float
class MyTokenizer:
def __init__(self) -> None:
self.token = torch.tensor([3.14])
def tokenize(self, input: Any) -> torch.Tensor:
return self.token
def torch_save_and_load(o: Any) -> Any:
"""
Writes the given object via torch.save, and then loads it back in.
"""
stream = BytesIO()
torch.save(o, stream)
stream.seek(0)
return torch.load(stream)
def test_serialization_roundtrip() -> None:
"""
Test that the core Torch model can be serialized and deserialized via torch.save/load.
"""
model = MyTestModule()
example_inputs = torch.randn((2, 3))
model_output = model.forward(example_inputs)
model_config = MyModelConfig(foo="foo", bar=3.14)
tokenizer = MyTokenizer()
image_preprocessing = Compose([Resize(size=20), CenterCrop(size=10)])
example_image = torch.randn((3, 30, 30))
image_output = image_preprocessing(example_image)
other_info = b'\x01\x02'
other_description = "a byte array"
info1 = ModelInfo(
model=model,
model_example_input=example_inputs,
model_config=model_config,
text_tokenizer=tokenizer,
git_repository="repo",
git_commit_hash="hash",
dataset_name="dataset",
azure_ml_workspace="workspace",
azure_ml_run_id="run_id",
image_dimensions="dimensions",
image_pre_processing=image_preprocessing,
other_info=other_info,
other_description=other_description,
)
state_dict = torch_save_and_load(info1.state_dict())
info2 = ModelInfo()
info2.load_state_dict(state_dict)
# Test if the deserialized model gives the same output as the original model
assert isinstance(info2.model, torch.jit.ScriptModule)
assert info1.model_example_input is not None
assert info2.model_example_input is not None
assert torch.allclose(info2.model_example_input, info1.model_example_input, atol=0, rtol=0)
serialized_output = info2.model.forward(info2.model_example_input)
assert torch.allclose(serialized_output, model_output, atol=0, rtol=0)
# Tokenizer should be written as a byte stream
assert isinstance(state_dict[ModelInfo.TEXT_TOKENIZER], bytes)
assert isinstance(state_dict[ModelInfo.IMAGE_PRE_PROCESSING], bytes)
assert info2.model_config == object_to_yaml(model_config)
assert info2.git_repository == "repo"
assert info2.git_commit_hash == "hash"
assert info2.dataset_name == "dataset"
assert info2.azure_ml_workspace == "workspace"
assert info2.azure_ml_run_id == "run_id"
assert info2.image_dimensions == "dimensions"
assert info2.other_info == other_info
assert info2.other_description == other_description
# Test if the deserialized preprocessing gives the same as the original object
assert info2.image_pre_processing is not None
image_output2 = info2.image_pre_processing(example_image)
assert torch.allclose(image_output, image_output2, atol=0, rtol=0)
def test_get_metadata() -> None:
"""Test if model metadata is read correctly from the AzureML run."""
run_name = "foo"
experiment_name = "himl-tests"
run: Optional[Run] = None
try:
run = create_aml_run_object(
experiment_name=experiment_name, run_name=run_name, workspace=DEFAULT_WORKSPACE.workspace
)
assert is_running_in_azure_ml(aml_run=run)
# This ModelInfo object has no fields pre-set
model_info = ModelInfo()
# If AzureML run info is already present in the object, those fields should be preserved.
model_info2 = ModelInfo(azure_ml_run_id="foo", azure_ml_workspace="bar")
with mock.patch("health_ml.utils.serialization.RUN_CONTEXT", run):
with mock.patch("health_ml.utils.serialization.is_running_in_azure_ml", return_value=True):
model_info.get_metadata_from_azureml()
model_info2.get_metadata_from_azureml()
assert model_info.azure_ml_run_id == run.id # type: ignore
assert model_info.azure_ml_workspace == DEFAULT_WORKSPACE.workspace.name
assert model_info2.azure_ml_run_id == "foo"
assert model_info2.azure_ml_workspace == "bar"
finally:
if run is not None:
run.complete()

View file

@@ -0,0 +1,21 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from torch.nn import Module
from health_ml.utils import set_model_to_eval_mode
def test_set_to_eval_mode() -> None:
model = Module()
model.train(True)
assert model.training
with set_model_to_eval_mode(model):
assert not model.training
assert model.training
model.eval()
assert not model.training
with set_model_to_eval_mode(model):
assert not model.training
assert not model.training

View file

@@ -24,6 +24,8 @@
"hi-ml-azure/src",
"hi-ml-azure/testazure",
"hi-ml/src",
"hi-ml-azure/src",
"hi-ml-azure/testazure",
]
},
{