Mirror of https://github.com/microsoft/hi-ml.git
ENH: Add support for model and metadata serialization (#252)
Added tools to write config objects to YAML and read back in
This commit is contained in:
Parent: cea7177e16
Commit: d2313f5336
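The commit exports the new traverse API (object_to_yaml, write_yaml_to_object) through the health_azure package, as the __init__.py hunks below show. For orientation, a minimal usage sketch; this is not part of the commit, and the OptimizerConfig dataclass is illustrative, modeled on the test fixtures further down:

    from dataclasses import dataclass

    from health_azure import object_to_yaml, write_yaml_to_object


    @dataclass
    class OptimizerConfig:  # illustrative config, mirroring the test fixtures below
        learning_rate: float = 1e-3
        optimizer: str = "Adam"


    config = OptimizerConfig()
    yaml_text = object_to_yaml(config)  # "learning_rate: 0.001\noptimizer: Adam\n" (keys sorted)

    # Round-trip: write the YAML back into a freshly constructed object.
    restored = OptimizerConfig(learning_rate=0.0, optimizer="")
    write_yaml_to_object(restored, yaml_text)
    assert restored == config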
@@ -7,8 +7,7 @@ outputs/
 __pycache__/
 .pytest_cache
 .mypy_cache
-logs
-outputs
+logs/
 config.json
 *.egg-info
 # Temporary files generated from conda merging
@@ -11,6 +11,7 @@ from health_azure.utils import (RUN_CONTEXT, aggregate_hyperdrive_metrics, creat
                                 fetch_run, get_most_recent_run, is_running_in_azure_ml,
                                 set_environment_variables_for_multi_node, split_recovery_id, torch_barrier,
                                 upload_to_datastore)
+from health_azure.traverse import object_to_yaml, write_yaml_to_object

 __all__ = [
     "AzureRunInfo",
@@ -33,5 +34,7 @@ __all__ = [
     "torch_barrier",
     "upload_to_datastore",
     "create_crossval_hyperdrive_config",
-    "aggregate_hyperdrive_metrics"
+    "aggregate_hyperdrive_metrics",
+    "object_to_yaml",
+    "write_yaml_to_object"
 ]
@@ -0,0 +1,267 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import enum
import logging
from io import StringIO
from typing import Any, Dict, Iterable, Union, List, Optional

import param
from ruamel import yaml


def is_basic_type(o: Any) -> bool:
    """
    Returns True if the given object is an instance of a basic datatype: string, integer, float.
    """
    return isinstance(o, (str, int, float))


def is_enum(o: Any) -> bool:
    """
    Returns True if the given object is an instance of enum.Enum.

    :param o: The object to inspect.
    :return: True if the object is an enum, False otherwise.
    """
    return isinstance(o, enum.Enum)


def get_all_writable_attributes(o: Any) -> Dict[str, Any]:
    """
    Returns all writable attributes of an object, by resorting to the "vars" method. For objects that derive
    from param.Parameterized, it returns all params that are not constant and not readonly.

    :param o: The object to inspect.
    :return: A dictionary mapping from attribute name to its value.
    """

    def _is_private(s: str) -> bool:
        return s.startswith("_")

    result = {}
    if isinstance(o, param.Parameterized):
        for param_name, p in o.params().items():
            if _is_private(param_name):
                logging.debug(f"get_all_writable_attributes: Skipping private field {param_name}")
            elif p.constant:
                logging.debug(f"get_all_writable_attributes: Skipping constant field {param_name}")
            elif p.readonly:
                logging.debug(f"get_all_writable_attributes: Skipping readonly field {param_name}")
            else:
                result[param_name] = getattr(o, param_name)
        return result
    try:
        for name, value in vars(o).items():
            if _is_private(name):
                logging.debug(f"get_all_writable_attributes: Skipping private field {name}")
            else:
                result[name] = value
        return result
    except TypeError:
        raise ValueError("This function can only be used on objects that support the 'vars' operation")


def all_basic_types(o: Iterable) -> bool:
    """Checks if all entries of the iterable are of a basic datatype (int, str, float).

    :param o: The iterable that should be checked.
    :return: True if all entries of the iterable are of a basic datatype.
    """
    for item in o:
        if not is_basic_type(item):
            return False
    return True


def _object_to_dict(o: Any) -> Union[None, int, float, str, List, Dict]:
    """
    Converts an object to a dictionary mapping from attribute name to value. That value can recursively be a
    dictionary, if the attribute is not a simple datatype. Lists and dictionaries are returned as-is.

    :param o: The object to inspect.
    :return: Returns the argument if the object is a basic datatype, otherwise a dictionary mapping from attribute
        name to value.
    """
    if is_basic_type(o):
        return o
    if isinstance(o, enum.Enum):
        return o.name
    if isinstance(o, list):
        if not all_basic_types(o):
            raise ValueError(f"Lists are only allowed to contain basic types (int, float, str), but got: {o}")
        return o
    if isinstance(o, dict):
        if not all_basic_types(o.keys()):
            raise ValueError(f"Dictionaries can only contain basic types (int, float, str) as keys, but got: {o}")
        if not all_basic_types(o.values()):
            raise ValueError(f"Dictionaries can only contain basic types (int, float, str) as values, but got: {o}")
        return o
    if o is None:
        return o
    try:
        fields = get_all_writable_attributes(o)
        return {field: _object_to_dict(value) for field, value in fields.items()}
    except ValueError as ex:
        raise ValueError(f"Unable to traverse object {o}: {ex}")


def object_to_dict(o: Any) -> Dict[str, Any]:
    """
    Converts an object to a dictionary mapping from attribute name to value. That value can recursively be a
    dictionary, if the attribute is not a simple datatype.
    This function only works on objects that are not themselves basic datatypes (i.e., classes). Private fields
    (names starting with underscores), or ones that appear to be constant or readonly, are omitted. For attributes
    that are Enums, the case name is returned as a string.

    :param o: The object to inspect.
    :return: A dictionary mapping from attribute name to value.
    :raises ValueError: If the argument is a basic datatype (int, str, float)
    """
    if is_basic_type(o):
        raise ValueError("This function can only be used on objects that are not basic datatypes.")
    fields = get_all_writable_attributes(o)
    result = {}
    for field, value in fields.items():
        logging.debug(f"object_to_dict: Processing {field}")
        result[field] = _object_to_dict(value)
    return result


def object_to_yaml(o: Any) -> str:
    """
    Converts an object to a YAML string representation. This is done by recursively traversing all attributes and
    writing them out to YAML if they are basic datatypes.

    :param o: The object to inspect.
    :return: A string in YAML format.
    """
    return yaml.safe_dump(object_to_dict(o), default_flow_style=False)  # type: ignore


def yaml_to_dict(s: str) -> Dict[str, Any]:
    """
    Interprets a string as YAML and returns the contents as a dictionary.

    :param s: The YAML string to parse.
    :return: A dictionary where the keys are the YAML field names, and values are either the YAML leaf node values,
        or dictionaries again.
    """
    stream = StringIO(s)
    return yaml.safe_load(stream=stream)


def _write_dict_to_object(o: Any, d: Dict[str, Any], traversed_fields: Optional[List] = None) -> List[str]:
    """
    Writes a dictionary of values into an object, assuming that the attributes of the object and the dictionary keys
    are in sync. For example, to write a dictionary {"foo": 1, "bar": "baz"} to an object, the object needs to have
    attributes "foo" and "bar".

    :param o: The object to write to.
    :param d: A dictionary mapping from attribute names to values or dictionaries recursively.
    :return: A list of error messages collected.
    """
    issues: List[str] = []
    traversed = traversed_fields or []

    def report_issue(name: str, message: str) -> None:
        full_field_name = ".".join(traversed + [name])
        issues.append(f"Attribute {full_field_name}: {message}")

    def try_set_field(name: str, value_to_write: Any) -> None:
        try:
            setattr(o, name, value_to_write)
        except Exception as ex:
            report_issue(name, f"Unable to set value {value_to_write}: {ex}")

    existing_attrs = get_all_writable_attributes(o)
    for name, value in existing_attrs.items():
        if name in d:
            value_to_write = d[name]
            t_value = type(value)
            t_value_to_write = type(value_to_write)
            if is_basic_type(value) and is_basic_type(value_to_write):
                if t_value != t_value_to_write:
                    report_issue(
                        name,
                        f"Skipped. Current value has type {t_value.__name__}, but trying to "
                        f"write {t_value_to_write.__name__}",
                    )
                else:
                    try_set_field(name, value_to_write)
            elif isinstance(value, enum.Enum):
                if isinstance(value_to_write, str):
                    try:
                        enum_case = getattr(t_value, value_to_write)
                    except Exception:
                        report_issue(name, f"Skipped. Enum type {t_value.__name__} has no case {value_to_write}")
                    else:
                        try_set_field(name, enum_case)
                else:
                    report_issue(
                        name,
                        "Skipped. This is an Enum field. Can only write string values to that field "
                        f"(case name), but got value of type {t_value_to_write.__name__}",
                    )
            elif value is None or value_to_write is None:
                # We can't do much type checking if we get Nones. This is a potential source of errors.
                try_set_field(name, value_to_write)
            elif isinstance(value, List) and isinstance(value_to_write, List):
                try_set_field(name, value_to_write)
            elif isinstance(value, Dict) and isinstance(value_to_write, Dict):
                try_set_field(name, value_to_write)
            elif not is_basic_type(value) and isinstance(value_to_write, Dict):
                # For anything that is not a basic datatype, we expect that we get a dictionary of fields
                # recursively.
                new_issues = _write_dict_to_object(
                    getattr(o, name), value_to_write, traversed_fields=traversed + [name]
                )
                issues.extend(new_issues)
            else:
                report_issue(
                    name,
                    f"Skipped. Current value has type {t_value.__name__}, but trying to "
                    f"write {t_value_to_write.__name__}",
                )
        else:
            report_issue(name, "Present in the object, but missing in the dictionary.")

    return issues


def write_dict_to_object(o: Any, d: Dict[str, Any], strict: bool = True) -> None:
    """
    Writes a dictionary of values into an object, assuming that the attributes of the object and the dictionary keys
    are in sync. For example, to write a dictionary {"foo": 1, "bar": "baz"} to an object, the object needs to have
    attributes "foo" and "bar".

    :param strict: If True, any mismatch of field names will raise a ValueError. If False, only a warning will be
        printed. Note that the object may have been modified even if an error is raised.
    :param o: The object to write to.
    :param d: A dictionary mapping from attribute names to values or dictionaries recursively.
    """
    issues = _write_dict_to_object(o, d)
    if len(issues) == 0:
        return
    message = f"Unable to complete writing to the object: Found {len(issues)} problems. Please inspect console log."
    for issue in issues:
        logging.warning(issue)
    if strict:
        raise ValueError(message)
    else:
        logging.warning(message)


def write_yaml_to_object(o: Any, yaml_string: str, strict: bool = False) -> None:
    """
    Writes a serialized object in YAML format back into an object, assuming that the attributes of the object and
    the YAML field names are in sync.

    :param strict: If True, any mismatch of field names will raise a ValueError. If False, only a warning will be
        printed. Note that the object may have been modified even if an error is raised.
    :param o: The object to write to.
    :param yaml_string: A YAML formatted string with attribute names and values.
    """
    d = yaml_to_dict(yaml_string)
    write_dict_to_object(o, d, strict=strict)
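A note on the strict flag introduced above: mismatches between dictionary and object are collected as per-field issue messages and logged; strict=True additionally turns them into a ValueError. A minimal sketch of both modes (the TrainingConfig class is hypothetical, not part of the commit):

    import logging
    from dataclasses import dataclass

    from health_azure.traverse import object_to_yaml, write_dict_to_object


    @dataclass
    class TrainingConfig:  # hypothetical config class, for illustration only
        epochs: int = 10
        rate: float = 0.5


    config = TrainingConfig()
    print(object_to_yaml(config))  # "epochs: 10\nrate: 0.5\n"

    # "epochs" is an int on the object, but a str in the dictionary: the field is
    # skipped and reported. With strict=True this surfaces as a ValueError.
    try:
        write_dict_to_object(config, {"epochs": "ten", "rate": 0.5}, strict=True)
    except ValueError as ex:
        logging.warning(ex)  # "Unable to complete writing to the object: Found 1 problems. ..."

    # With strict=False the mismatch is only logged; matching fields are still written.
    write_dict_to_object(config, {"epochs": "ten", "rate": 0.25}, strict=False)
    assert config.rate == 0.25 and config.epochs == 10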
@@ -72,7 +72,7 @@ ENVIRONMENT_VERSION = "1"
 FINAL_MODEL_FOLDER = "final_model"
 MODEL_ID_KEY_NAME = "model_id"
 PYTHON_ENVIRONMENT_NAME = "python_environment_name"
-RUN_CONTEXT = Run.get_context()
+RUN_CONTEXT: Run = Run.get_context()
 PARENT_RUN_CONTEXT = getattr(RUN_CONTEXT, "parent", None)
 WORKSPACE_CONFIG_JSON = "config.json"
@@ -0,0 +1,389 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import enum
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

import param
import pytest
from _pytest.logging import LogCaptureFixture

from health_azure.traverse import (
    _object_to_dict,
    get_all_writable_attributes,
    is_basic_type,
    object_to_dict,
    object_to_yaml,
    yaml_to_dict,
    write_dict_to_object,
    write_yaml_to_object,
    _write_dict_to_object,
)


@dataclass
class OptimizerConfig:
    learning_rate: float = 1e-3
    optimizer: Optional[str] = "Adam"


@dataclass
class TransformConfig:
    blur_sigma: float = 0.1
    blur_p: float = 0.2


@dataclass
class FullConfig:
    optimizer: OptimizerConfig = OptimizerConfig()
    transforms: TransformConfig = TransformConfig()


@dataclass
class TripleNestedConfig:
    float1: float = 1.0
    int1: int = 1
    nested: FullConfig = FullConfig()


@dataclass
class ConfigWithList:
    list_field: Union[List, Dict] = field(default_factory=list)


class ParamsConfig(param.Parameterized):
    p1: int = param.Integer(default=1)
    p2: str = param.String(default="foo")
    p3: float = param.Number(default=2.0)


class ParamsConfigWithReadonly(param.Parameterized):
    p1: int = param.Integer(default=1, readonly=True)
    p2: str = param.String(default="foo", constant=True)
    p3: float = param.Number(default=2.0)
    _p4: float = param.Number(default=4.0)


class MyEnum(enum.Enum):
    foo = "foo_value"
    bar = "bar_value"


@dataclass
class ConfigWithEnum:
    f1: MyEnum


def test_traverse1() -> None:
    config = TransformConfig()
    d = object_to_dict(config)
    assert d == {"blur_sigma": 0.1, "blur_p": 0.2}


def test_traverse2() -> None:
    config = FullConfig()
    d = object_to_dict(config)
    assert d == {
        "optimizer": {"learning_rate": 1e-3, "optimizer": "Adam"},
        "transforms": {"blur_sigma": 0.1, "blur_p": 0.2},
    }


def test_traverse_params() -> None:
    config = ParamsConfig()
    d = object_to_dict(config)
    assert d == {"p1": 1, "p2": "foo", "p3": 2.0}


def test_traverse_params_readonly() -> None:
    """
    Private, constant and readonly fields of a Params object should be skipped.
    """
    config = ParamsConfigWithReadonly()
    d = get_all_writable_attributes(config)
    assert d == {"p3": 2.0}


def test_traverse_list() -> None:
    """Lists of basic datatypes should be serialized"""
    list_config = ConfigWithList(list_field=[1, 2.5, "foo"])
    assert object_to_yaml(list_config) == """list_field:
- 1
- 2.5
- foo
"""
    list_config.list_field = [1, dict()]
    with pytest.raises(ValueError) as ex:
        object_to_dict(list_config)
    assert "only allowed to contain basic types" in str(ex)


def test_is_basic() -> None:
    assert is_basic_type(1)
    assert is_basic_type(1.0)
    assert is_basic_type("foo")
    assert not is_basic_type(None)
    assert not is_basic_type(dict())
    assert not is_basic_type([])


def test_traverse_dict() -> None:
    """Fields with dictionaries of basic datatypes should be serialized"""
    list_config = ConfigWithList(list_field={1: "foo", "bar": 2.0})
    assert object_to_yaml(list_config) == """list_field:
  1: foo
  bar: 2.0
"""
    # Invalid dictionaries contain non-basic types either as keys or as values
    for invalid in [{"foo": dict()}, {(1, 2): "foo"}]:
        list_config.list_field = invalid  # type: ignore
        with pytest.raises(ValueError) as ex:
            object_to_dict(list_config)
        assert "Dictionaries can only contain basic types" in str(ex)


def test_traverse_enum() -> None:
    """
    Enum objects have no non-private fields, and should return an empty value dictionary.
    """
    config = MyEnum.foo
    d = get_all_writable_attributes(config)
    assert len(d) == 0
    # Enum objects should be printed out by their case name
    d = _object_to_dict(config)
    assert d == "foo"


def test_traverse_none() -> None:
    """
    Attributes that are None should be preserved as such in YAML.
    """
    config = OptimizerConfig()
    config.optimizer = None
    assert _object_to_dict(None) is None
    assert _object_to_dict(config) == {"learning_rate": 1e-3, "optimizer": None}
    assert object_to_yaml(config) == """learning_rate: 0.001
optimizer: null
"""


def test_to_yaml_roundtrip() -> None:
    config = FullConfig()
    yaml = object_to_yaml(config)
    print("\n" + yaml)
    dict = yaml_to_dict(yaml)
    assert dict == object_to_dict(config)


def test_params_roundtrip() -> None:
    config = ParamsConfig()
    yaml = object_to_yaml(config)
    print("\n" + yaml)
    dict = yaml_to_dict(yaml)
    assert dict == object_to_dict(config)


@pytest.mark.parametrize("my_list", [[1, "foo"], [], {1: 2.0}, {}])
def test_list_dict_roundtrip(my_list: Any) -> None:
    """Test if configs with lists and dicts can be converted to YAML and back"""
    my_list = [1, "foo"]
    config = ConfigWithList(list_field=my_list)
    yaml = object_to_yaml(config)
    print("\n" + yaml)
    # Initialize the new object with a non-empty list, to see if the new value gets correctly written
    config2 = ConfigWithList(list_field=["not_in_args"])
    write_yaml_to_object(config2, yaml)
    assert config2.list_field == my_list


def test_write_flat() -> None:
    obj = OptimizerConfig()
    learning_rate = 3.0
    optimizer = "Foo"
    d = {"learning_rate": learning_rate, "optimizer": optimizer}
    write_dict_to_object(obj, d)
    assert obj.learning_rate == learning_rate
    assert obj.optimizer == optimizer


def test_write_nested() -> None:
    obj = TripleNestedConfig()
    yaml = object_to_yaml(obj)
    print("\n" + yaml)
    float1 = 2.0
    int1 = 2
    blur_p = 7.0
    blur_sigma = 8.0
    learning_rate = 3.0
    optimizer = "Foo"
    d1 = {"transforms": {"blur_sigma": blur_sigma, "blur_p": blur_p},
          "optimizer": {"learning_rate": learning_rate, "optimizer": optimizer}}
    d = {"float1": float1, "int1": int1, "nested": d1}
    write_dict_to_object(obj, d)
    assert obj.float1 == float1
    assert obj.int1 == int1
    assert obj.nested.optimizer.optimizer == optimizer
    assert obj.nested.optimizer.learning_rate == learning_rate
    assert obj.nested.transforms.blur_p == blur_p
    assert obj.nested.transforms.blur_sigma == blur_sigma

    from_obj = object_to_dict(obj)
    assert from_obj == d

    yaml = object_to_yaml(obj)

    print("\n" + yaml)
    obj2 = TripleNestedConfig()
    write_yaml_to_object(obj2, yaml_string=yaml)
    print("\n" + repr(obj2))
    from_yaml_obj = object_to_dict(obj2)
    assert from_yaml_obj == d


def test_yaml_and_write_roundtrip() -> None:
    obj = ParamsConfig()
    obj.p1 = 2
    obj.p2 = "nothing"
    obj.p3 = 3.14
    yaml = object_to_yaml(obj)
    print("\n" + yaml)
    obj2 = ParamsConfig()
    write_yaml_to_object(obj2, yaml_string=yaml)
    assert obj2.p1 == obj.p1
    assert obj2.p2 == obj.p2
    assert obj2.p3 == obj.p3


def test_to_yaml_datatypes() -> None:
    """
    Ensure that string fields that look like numbers are treated correctly.
    """
    config = OptimizerConfig()
    config.optimizer = "2"
    yaml = object_to_yaml(config)
    print("\n" + yaml)
    dict = yaml_to_dict(yaml)
    assert dict == object_to_dict(config)


def test_object_to_yaml_floats() -> None:
    """
    Check that floating point numbers that could be mistaken for integers are written to YAML unambiguously.
    """
    config = TransformConfig()
    config.blur_p = 1.0
    yaml = object_to_yaml(config)
    assert yaml == """blur_p: 1.0
blur_sigma: 0.1
"""


def test_write_dict_errors1(caplog: LogCaptureFixture) -> None:
    """
    Check type mismatches between object and YAML contents.
    """
    # First test the private writer method
    config = TransformConfig()
    dict = {"blur_p": 1, "blur_sigma": 0.1}
    errors = _write_dict_to_object(config, dict)
    assert len(errors) == 1
    assert "Attribute blur_p" in errors[0]
    assert "Skipped" in errors[0]
    assert "Current value has type float" in errors[0]
    assert "trying to write int" in errors[0]

    # The same error message should be raised when calling the full method, with either strict or non-strict
    config = TransformConfig()
    with caplog.at_level(level=logging.WARNING):
        with pytest.raises(ValueError) as ex:
            write_dict_to_object(config, dict, strict=True)
        assert "Found 1 problems" in str(ex)
        assert errors[0] in caplog.text

    config = TransformConfig()
    with caplog.at_level(level=logging.WARNING):
        write_dict_to_object(config, dict, strict=False)
        assert errors[0] in caplog.text


def test_write_dict_errors2(caplog: LogCaptureFixture) -> None:
    """
    Check type mismatches between object and YAML contents, and that field names are correctly handled.
    """
    caplog.set_level(logging.WARNING)
    config = FullConfig()
    dict = object_to_dict(config)
    dict["transforms"]["blur_p"] = "foo"
    dict["optimizer"]["learning_rate"] = "bar"
    write_dict_to_object(config, dict, strict=False)
    assert "Found 2 problems" in caplog.text
    assert "Attribute transforms.blur_p" in caplog.text
    assert "Attribute optimizer.learning_rate" in caplog.text


def test_write_dict_errors3() -> None:
    """
    Check handling of cases where the object has more fields than present in the dictionary.
    """
    config = TransformConfig()
    dict = {"blur_p": 1.0}
    issues = _write_dict_to_object(config, dict)
    assert len(issues) == 1
    assert "Present in the object, but missing in the dictionary" in issues[0]
    assert "Attribute blur_sigma" in issues[0]


def test_write_enums() -> None:
    """
    Test handling of fields that are enum types.
    """
    config = ConfigWithEnum(f1=MyEnum.bar)
    d = object_to_dict(config)
    assert d == {"f1": "bar"}
    # Now change f1 to be another enum value, write back the previous dictionary, and check if f1 is back at the
    # original value
    config.f1 = MyEnum.foo
    write_dict_to_object(config, d)
    assert config.f1 == MyEnum.bar


def test_write_enums_errors() -> None:
    """
    Error handling for enum fields.
    """
    config = ConfigWithEnum(f1=MyEnum.bar)
    # Trying to write a floating point number, but we expect a string case name
    issues = _write_dict_to_object(config, {"f1": 1.0})
    assert len(issues) == 1
    assert "Enum field" in issues[0]
    assert "got value of type float" in issues[0]
    # Referencing an enum case that does not exist
    issues = _write_dict_to_object(config, {"f1": "no_such_case"})
    assert len(issues) == 1
    assert "Enum type MyEnum has no case no_such_case" in issues[0]
    # This should work fine
    issues = _write_dict_to_object(config, {"f1": MyEnum.bar.name})
    assert len(issues) == 0


def test_write_null() -> None:
    """
    Round-trip test for writing fields that are None.
    """
    config = OptimizerConfig()
    config.optimizer = None
    dict = object_to_dict(config)
    yaml = object_to_yaml(config)
    config.optimizer = "foo"
    issues = _write_dict_to_object(config, dict)
    assert len(issues) == 0
    assert config.optimizer is None
    config.optimizer = "foo"
    write_yaml_to_object(config, yaml)
    assert config.optimizer is None
@@ -9,7 +9,7 @@ from pathlib import Path
 from typing import Any, List, Optional, Tuple, TypeVar

 from pytorch_lightning import Callback, Trainer, seed_everything
-from pytorch_lightning.callbacks import GPUStatsMonitor, ModelCheckpoint, TQDMProgressBar
+from pytorch_lightning.callbacks import GPUStatsMonitor, ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.plugins import DDPPlugin
@@ -134,6 +134,8 @@ def create_lightning_trainer(container: LightningContainer,
                                              write_to_logging_info=True,
                                              print_timestamp=False))
     else:
+        # Use a local import here to be able to support older PL versions
+        from pytorch_lightning.callbacks import TQDMProgressBar
         callbacks.append(TQDMProgressBar(refresh_rate=progress_bar_refresh_rate))
     # Read out additional model-specific args here.
     # We probably want to keep essential ones like numgpu and logging.
@@ -5,6 +5,7 @@

 from health_ml.utils.logging import AzureMLLogger, AzureMLProgressBar, log_learning_rate, log_on_epoch
 from health_ml.utils.diagnostics import BatchTimeCallback
+from health_ml.utils.common_utils import set_model_to_eval_mode

 __all__ = [
     "AzureMLLogger",
@@ -12,4 +13,5 @@ __all__ = [
     "BatchTimeCallback",
     "log_learning_rate",
     "log_on_epoch",
+    "set_model_to_eval_mode"
 ]
@@ -1,3 +1,8 @@
+# ------------------------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
+# ------------------------------------------------------------------------------------------
+
 import logging
 import os
 import sys

@@ -16,7 +21,6 @@ from health_azure.paths import ENVIRONMENT_YAML_FILE_NAME, git_repo_root_folder,

 from health_azure.utils import PathOrString, is_conda_file_with_pip_include

-
 MAX_PATH_LENGTH = 260

 # convert string to None if an empty string or whitespace is provided
@@ -1,8 +1,8 @@
-
 # ------------------------------------------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
+
 import base64
 import itertools
 import mimetypes
@@ -0,0 +1,173 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import pickle
from io import BytesIO
from typing import Any, Optional, Callable, Dict, Union

import torch

from health_azure import RUN_CONTEXT, object_to_yaml, is_running_in_azure_ml


def _dump_to_stream(o: Any) -> BytesIO:
    """
    Writes the given object to a byte stream in pickle format, and returns the stream.

    :param o: The object to pickle.
    :return: A byte stream, with the position set to 0.
    """
    stream = BytesIO()
    pickle.dump(o, file=stream)
    stream.seek(0)
    return stream


class ModelInfo:
    """Stores a model, its example input, and metadata that describes how the model was trained."""

    MODEL = "model"
    MODEL_EXAMPLE_INPUT = "model_example_input"
    MODEL_CONFIG_YAML = "model_config_yaml"
    GIT_REPOSITORY = "git_repository"
    GIT_COMMIT_HASH = "git_commit_hash"
    DATASET_NAME = "dataset_name"
    AZURE_ML_WORKSPACE = "azure_ml_workspace"
    AZURE_ML_RUN_ID = "azure_ml_run_id"
    TEXT_TOKENIZER = "text_tokenizer"
    IMAGE_PRE_PROCESSING = "image_pre_processing"
    IMAGE_DIMENSIONS = "image_dimensions"
    OTHER_INFO = "other_info"
    OTHER_DESCRIPTION = "other_description"

    def __init__(self,
                 model: Optional[Union[torch.nn.Module, torch.jit.ScriptModule]] = None,
                 model_example_input: Optional[torch.Tensor] = None,
                 model_config: Any = None,
                 git_repository: str = "",
                 git_commit_hash: str = "",
                 dataset_name: str = "",
                 azure_ml_workspace: str = "",
                 azure_ml_run_id: str = "",
                 text_tokenizer: Any = None,
                 image_pre_processing: Optional[Callable] = None,
                 image_dimensions: str = "",
                 other_info: Any = None,
                 other_description: str = "",
                 ):
        """
        :param model: The model that should be serialized, or the deserialized model, defaults to None
        :param model_example_input: A tensor that can be input to the forward pass of the model, defaults to None
        :param model_config: The configuration object that was used to start the training run, defaults to None
        :param git_repository: The name of the git repository that contains the training codebase, defaults to ""
        :param git_commit_hash: The git commit hash that was used to run the training, defaults to ""
        :param dataset_name: The name of the dataset that was used to train the model, defaults to ""
        :param azure_ml_workspace: The name of the AzureML workspace that contains the training run, defaults to ""
        :param azure_ml_run_id: The AzureML run that did the training, defaults to ""
        :param text_tokenizer: A text tokenizer object to pre-process the model input (default: None). The object
            given here will be pickled before it is passed to ``torch.save``.
        :param image_pre_processing: An object that describes the processing for the image before it is input to the
            model, defaults to None
        :param image_dimensions: The size of the pre-processed image that is accepted by the model, defaults to ""
        :param other_info: An arbitrary object that will also be written to the checkpoint. For example, this can be
            a binary stream. The object provided here will be pickled before it is passed to ``torch.save``.
        :param other_description: A human-readable description of what the ``other_info`` field contains.
        """
        self.model = model
        self.model_example_input = model_example_input
        self.model_config = model_config
        self.git_repository = git_repository
        self.git_commit_hash = git_commit_hash
        self.dataset_name = dataset_name
        self.azure_ml_workspace = azure_ml_workspace
        self.azure_ml_run_id = azure_ml_run_id
        self.text_tokenizer = text_tokenizer
        self.image_pre_processing = image_pre_processing
        self.image_dimensions = image_dimensions
        self.other_info = other_info
        self.other_description = other_description

    def get_metadata_from_azureml(self) -> None:
        """Reads information about the git repository and AzureML-related info from the AzureML run context.
        If any of this information is already stored in the object, the existing values take priority.
        """
        if not is_running_in_azure_ml():
            return
        if not self.azure_ml_workspace:
            self.azure_ml_workspace = RUN_CONTEXT.experiment.workspace.name
        if not self.azure_ml_run_id:
            self.azure_ml_run_id = RUN_CONTEXT.id
        properties: Dict = RUN_CONTEXT.get_properties()
        if not self.git_repository:
            self.git_repository = properties.get("azureml.git.repository_uri", "")
        if not self.git_commit_hash:
            self.git_commit_hash = properties.get("azureml.git.commit", "")

    def state_dict(self, strict: bool = True) -> Dict[str, Any]:
        """Creates a dictionary representation of the current object.

        :param strict: The setting for the 'strict' flag in the call to torch.jit.trace.
        """
        def bytes_or_none(o: Any) -> Optional[bytes]:
            return _dump_to_stream(o).getvalue() if o is not None else None

        if self.model is None or self.model_example_input is None:
            raise ValueError("To generate a state dict, the model and model_example_input must be present.")
        try:
            traced_model = torch.jit.trace(self.model, self.model_example_input, strict=strict)
        except Exception as ex:
            raise ValueError(f"Unable to convert the model to torchscript: {ex}")
        jit_stream = BytesIO()
        torch.jit.save(traced_model, jit_stream)
        try:
            config_yaml = object_to_yaml(self.model_config) if self.model_config is not None else None
        except Exception as ex:
            raise ValueError(f"Unable to convert the model configuration to YAML: {ex}")
        return {
            ModelInfo.MODEL: jit_stream.getvalue(),
            ModelInfo.MODEL_EXAMPLE_INPUT: self.model_example_input,
            ModelInfo.MODEL_CONFIG_YAML: config_yaml,
            ModelInfo.GIT_REPOSITORY: self.git_repository,
            ModelInfo.GIT_COMMIT_HASH: self.git_commit_hash,
            ModelInfo.DATASET_NAME: self.dataset_name,
            ModelInfo.AZURE_ML_WORKSPACE: self.azure_ml_workspace,
            ModelInfo.AZURE_ML_RUN_ID: self.azure_ml_run_id,
            ModelInfo.TEXT_TOKENIZER: bytes_or_none(self.text_tokenizer),
            ModelInfo.IMAGE_PRE_PROCESSING: bytes_or_none(self.image_pre_processing),
            ModelInfo.IMAGE_DIMENSIONS: self.image_dimensions,
            ModelInfo.OTHER_INFO: bytes_or_none(self.other_info),
            ModelInfo.OTHER_DESCRIPTION: self.other_description
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Loads a dictionary representation into the current object, overwriting all matching fields.

        :param state_dict: The dictionary to load from.
        """
        def unpickle_from_bytes(field: str) -> Any:
            if field not in state_dict:
                raise KeyError(f"State_dict does not contain a field '{field}'")
            b = state_dict[field]
            if b is None:
                return None
            stream = BytesIO(b)
            try:
                o = pickle.load(stream)
            except Exception as ex:
                raise ValueError(f"Failure when unpickling field '{field}': {ex}")
            return o

        self.model = torch.jit.load(BytesIO(state_dict[ModelInfo.MODEL]))
        self.model_example_input = state_dict[ModelInfo.MODEL_EXAMPLE_INPUT]
        self.model_config = state_dict[ModelInfo.MODEL_CONFIG_YAML]
        self.git_repository = state_dict[ModelInfo.GIT_REPOSITORY]
        self.git_commit_hash = state_dict[ModelInfo.GIT_COMMIT_HASH]
        self.dataset_name = state_dict[ModelInfo.DATASET_NAME]
        self.azure_ml_workspace = state_dict[ModelInfo.AZURE_ML_WORKSPACE]
        self.azure_ml_run_id = state_dict[ModelInfo.AZURE_ML_RUN_ID]
        self.text_tokenizer = unpickle_from_bytes(ModelInfo.TEXT_TOKENIZER)
        self.image_pre_processing = unpickle_from_bytes(ModelInfo.IMAGE_PRE_PROCESSING)
        self.image_dimensions = state_dict[ModelInfo.IMAGE_DIMENSIONS]
        self.other_info = unpickle_from_bytes(ModelInfo.OTHER_INFO)
        self.other_description = state_dict[ModelInfo.OTHER_DESCRIPTION]
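For orientation, a minimal sketch of how ModelInfo is meant to be used together with torch.save/torch.load; the Linear model and the example input here are illustrative, not part of the commit:

    from io import BytesIO

    import torch

    from health_ml.utils.serialization import ModelInfo

    # A hypothetical single-layer model and a matching example input.
    model = torch.nn.Linear(4, 2)
    example_input = torch.randn((1, 4))

    info = ModelInfo(model=model, model_example_input=example_input, dataset_name="my_dataset")
    info.get_metadata_from_azureml()  # no-op outside of AzureML

    # state_dict() traces the model via torch.jit.trace and packages it with the metadata.
    checkpoint = BytesIO()
    torch.save(info.state_dict(), checkpoint)
    checkpoint.seek(0)

    # Loading restores the model as a torch.jit.ScriptModule.
    restored = ModelInfo()
    restored.load_state_dict(torch.load(checkpoint))
    assert torch.allclose(restored.model(example_input), model(example_input))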
@@ -0,0 +1,131 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Optional
from unittest import mock

import torch
from torchvision.transforms import Compose, Resize, CenterCrop
from azureml.core import Run

from health_azure import object_to_yaml, create_aml_run_object
from health_azure.utils import is_running_in_azure_ml
from health_ml.utils.serialization import ModelInfo
from testazure.utils_testazure import DEFAULT_WORKSPACE


class MyTestModule(torch.nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.max(input)


@dataclass
class MyModelConfig:
    foo: str
    bar: float


class MyTokenizer:
    def __init__(self) -> None:
        self.token = torch.tensor([3.14])

    def tokenize(self, input: Any) -> torch.Tensor:
        return self.token


def torch_save_and_load(o: Any) -> Any:
    """
    Writes the given object via torch.save, and then loads it back in.
    """
    stream = BytesIO()
    torch.save(o, stream)
    stream.seek(0)
    return torch.load(stream)


def test_serialization_roundtrip() -> None:
    """
    Test that the core Torch model can be serialized and deserialized via torch.save/load.
    """
    model = MyTestModule()
    example_inputs = torch.randn((2, 3))
    model_output = model.forward(example_inputs)
    model_config = MyModelConfig(foo="foo", bar=3.14)
    tokenizer = MyTokenizer()
    image_preprocessing = Compose([Resize(size=20), CenterCrop(size=10)])
    example_image = torch.randn((3, 30, 30))
    image_output = image_preprocessing(example_image)
    other_info = b'\x01\x02'
    other_description = "a byte array"
    info1 = ModelInfo(
        model=model,
        model_example_input=example_inputs,
        model_config=model_config,
        text_tokenizer=tokenizer,
        git_repository="repo",
        git_commit_hash="hash",
        dataset_name="dataset",
        azure_ml_workspace="workspace",
        azure_ml_run_id="run_id",
        image_dimensions="dimensions",
        image_pre_processing=image_preprocessing,
        other_info=other_info,
        other_description=other_description,
    )

    state_dict = torch_save_and_load(info1.state_dict())
    info2 = ModelInfo()
    info2.load_state_dict(state_dict)
    # Test if the deserialized model gives the same output as the original model
    assert isinstance(info2.model, torch.jit.ScriptModule)
    assert info1.model_example_input is not None
    assert info2.model_example_input is not None
    assert torch.allclose(info2.model_example_input, info1.model_example_input, atol=0, rtol=0)
    serialized_output = info2.model.forward(info2.model_example_input)
    assert torch.allclose(serialized_output, model_output, atol=0, rtol=0)
    # Tokenizer should be written as a byte stream
    assert isinstance(state_dict[ModelInfo.TEXT_TOKENIZER], bytes)
    assert isinstance(state_dict[ModelInfo.IMAGE_PRE_PROCESSING], bytes)
    assert info2.model_config == object_to_yaml(model_config)
    assert info2.git_repository == "repo"
    assert info2.git_commit_hash == "hash"
    assert info2.dataset_name == "dataset"
    assert info2.azure_ml_workspace == "workspace"
    assert info2.azure_ml_run_id == "run_id"
    assert info2.image_dimensions == "dimensions"
    assert info2.other_info == other_info
    assert info2.other_description == other_description
    # Test if the deserialized preprocessing gives the same result as the original object
    assert info2.image_pre_processing is not None
    image_output2 = info2.image_pre_processing(example_image)
    assert torch.allclose(image_output, image_output2, atol=0, rtol=0)


def test_get_metadata() -> None:
    """Test if model metadata is read correctly from the AzureML run."""
    run_name = "foo"
    experiment_name = "himl-tests"
    run: Optional[Run] = None
    try:
        run = create_aml_run_object(
            experiment_name=experiment_name, run_name=run_name, workspace=DEFAULT_WORKSPACE.workspace
        )
        assert is_running_in_azure_ml(aml_run=run)
        # This ModelInfo object has no fields pre-set
        model_info = ModelInfo()
        # If AzureML run info is already present in the object, those fields should be preserved.
        model_info2 = ModelInfo(azure_ml_run_id="foo", azure_ml_workspace="bar")
        with mock.patch("health_ml.utils.serialization.RUN_CONTEXT", run):
            with mock.patch("health_ml.utils.serialization.is_running_in_azure_ml", return_value=True):
                model_info.get_metadata_from_azureml()
                model_info2.get_metadata_from_azureml()
        assert model_info.azure_ml_run_id == run.id  # type: ignore
        assert model_info.azure_ml_workspace == DEFAULT_WORKSPACE.workspace.name
        assert model_info2.azure_ml_run_id == "foo"
        assert model_info2.azure_ml_workspace == "bar"
    finally:
        if run is not None:
            run.complete()
@@ -0,0 +1,21 @@
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from torch.nn import Module

from health_ml.utils import set_model_to_eval_mode


def test_set_to_eval_mode() -> None:
    model = Module()
    model.train(True)
    assert model.training
    with set_model_to_eval_mode(model):
        assert not model.training
    assert model.training

    model.eval()
    assert not model.training
    with set_model_to_eval_mode(model):
        assert not model.training
    assert not model.training
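set_model_to_eval_mode is imported from health_ml.utils.common_utils, but its implementation is not part of this diff. Based on the behavior the test asserts (eval mode inside the context, the previous training mode restored afterwards), a plausible sketch is:

    from contextlib import contextmanager
    from typing import Generator

    from torch.nn import Module


    @contextmanager
    def set_model_to_eval_mode(model: Module) -> Generator:
        """Puts the model into eval mode, restoring its previous training state on exit."""
        old_mode = model.training  # remember whether the model was in training mode (assumption: this mirrors the real helper)
        model.eval()
        yield
        model.train(old_mode)  # restore the original mode, matching the behavior the test checks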
@@ -24,6 +24,8 @@
         "hi-ml-azure/src",
         "hi-ml-azure/testazure",
         "hi-ml/src",
+        "hi-ml-azure/src",
+        "hi-ml-azure/testazure",
     ]
 },
 {