schema: refactoring to provide serializing method

The original way is one line, but not straightforward, and need type
ignore annotation. With the methods, the readability is improved.
This commit is contained in:
Chi Song 2021-08-17 17:07:43 -07:00 коммит произвёл Chi Song
Родитель 2e9cd41342
Коммит bd133486bb
15 изменённых файлов: 77 добавлений и 72 удалений

Просмотреть файл

@ -24,9 +24,7 @@ def run(args: Namespace) -> int:
notifier_data = builder.partial_resolve(constants.NOTIFIER)
if notifier_data:
notifier_runbook = schema.Notifier.schema().load( # type: ignore
notifier_data, many=True
)
notifier_runbook = schema.load_by_type_many(schema.Notifier, notifier_data)
notifier.initialize(runbooks=notifier_runbook)
run_message = notifier.TestRunMessage(
test_project=builder.partial_resolve(constants.TEST_PROJECT),

Просмотреть файл

@ -226,7 +226,7 @@ class RunbookBuilder:
and data_from_include[constants.VARIABLE]
):
variables_from_include = [
schema.Variable.schema().load(variable) # type: ignore
schema.load_by_type(schema.Variable, variable)
for variable in data_from_include[constants.VARIABLE]
]
# resolve to absolute path
@ -238,7 +238,7 @@ class RunbookBuilder:
and data_from_current[constants.VARIABLE]
):
variables_from_current: List[schema.Variable] = [
schema.Variable.schema().load(variable) # type: ignore
schema.load_by_type(schema.Variable, variable)
for variable in data_from_current[constants.VARIABLE]
]
@ -345,7 +345,7 @@ class RunbookBuilder:
for include_raw in includes:
try:
include: schema.Include
include = schema.Include.schema().load(include_raw) # type: ignore
include = schema.load_by_type(schema.Include, include_raw)
except Exception as identifer:
raise LisaException(
f"error on loading include node [{include_raw}]: {identifer}"

Просмотреть файл

@ -195,8 +195,6 @@ def load_platform(platforms_runbook: List[schema.Platform]) -> Platform:
def load_platform_from_builder(runbook_builder: RunbookBuilder) -> Platform:
platform_runbook_data = runbook_builder.partial_resolve(constants.PLATFORM)
platform_runbook = schema.Platform.schema().load( # type: ignore
platform_runbook_data, many=True
)
platform_runbook = schema.load_by_type_many(schema.Platform, platform_runbook_data)
platform = load_platform(platform_runbook)
return platform

Просмотреть файл

@ -210,7 +210,7 @@ class RootRunner(Action):
runner_filters: Dict[str, List[schema.BaseTestCaseFilter]] = {}
for raw_filter in runbook.testcase_raw:
# by default run all filtered cases unless 'enable' is specified as false
filter = schema.BaseTestCaseFilter.schema().load(raw_filter) # type:ignore
filter = schema.load_by_type(schema.BaseTestCaseFilter, raw_filter)
if filter.enable:
raw_filters: List[schema.BaseTestCaseFilter] = runner_filters.get(
filter.type, []

Просмотреть файл

@ -608,10 +608,8 @@ class LisaRunner(BaseRunner):
node_requirement_data = deep_update_dict(
platform_requirement_data, node_requirement_data
)
node_requirement = (
schema.NodeSpace.schema().load( # type: ignore
node_requirement_data
)
node_requirement = schema.load_by_type(
schema.NodeSpace, node_requirement_data
)
environment_requirement.nodes[index] = node_requirement

Просмотреть файл

@ -120,8 +120,8 @@ class ExtendableSchemaMixin:
runbook_type=runbook_type, type_name=type_name
)
if self.extended_schemas and type_name in self.extended_schemas:
self._extended_runbook: T = runbook_type.schema().load( # type:ignore
self.extended_schemas[type_name]
self._extended_runbook: T = load_by_type(
runbook_type, self.extended_schemas[type_name]
)
else:
# value may be filled outside, so hold and return an object.
@ -266,7 +266,7 @@ class Extension:
if isinstance(extension, str):
extension = Extension(path=extension)
elif isinstance(extension, dict):
extension = Extension.schema().load(extension) # type: ignore
extension = load_by_type(Extension, extension)
results.append(extension)
return results
@ -324,10 +324,7 @@ class Variable:
if isinstance(self.value_raw, dict):
self.value: Union[
str, bool, int, VariableEntry, List[Union[str, bool, int]]
] = cast(
VariableEntry,
VariableEntry.schema().load(self.value_raw), # type:ignore
)
] = load_by_type(VariableEntry, self.value_raw)
else:
self.value = self.value_raw
@ -676,18 +673,14 @@ class Environment:
for node_raw in self.nodes_raw:
node_type = node_raw[constants.TYPE]
if node_type == constants.ENVIRONMENTS_NODES_REQUIREMENT:
original_req: NodeSpace = NodeSpace.schema().load( # type:ignore
node_raw
)
original_req: NodeSpace = load_by_type(NodeSpace, node_raw)
expanded_req = original_req.expand_by_node_count()
if self.nodes_requirement is None:
self.nodes_requirement = []
self.nodes_requirement.extend(expanded_req)
else:
# load base schema for future parsing
node: Node = Node.schema().load( # type:ignore
node_raw
)
node: Node = load_by_type(Node, node_raw)
results.append(node)
self.nodes_raw = None
@ -764,7 +757,7 @@ class Platform(TypedSchema, ExtendableSchemaMixin):
# But the schema will be validated here. The original NodeSpace object holds
if self.requirement:
# validate schema of raw inputs
Capability.schema().load(self.requirement) # type: ignore
load_by_type(Capability, self.requirement)
@dataclass_json()
@ -916,3 +909,26 @@ class Runbook:
}
]
self.testcase: List[Any] = []
def load_by_type(schema_type: Type[T], raw_runbook: Any, many: bool = False) -> T:
"""
Convert dict, list or base typed schema to specified typed schema.
"""
if type(raw_runbook) == schema_type:
return raw_runbook
if not isinstance(raw_runbook, dict) and not many:
raw_runbook = raw_runbook.to_dict()
result: T = schema_type.schema().load(raw_runbook, many=many) # type: ignore
return result
def load_by_type_many(schema_type: Type[T], raw_runbook: Any) -> List[T]:
"""
Convert raw list to list of typed schema. It has different returned type
with load_by_type.
"""
result = load_by_type(schema_type, raw_runbook=raw_runbook, many=True)
return cast(List[T], result)

Просмотреть файл

@ -132,8 +132,8 @@ class AzureNodeSchema:
self.marketplace_raw = dict(
(k, v.lower()) for k, v in self.marketplace_raw.items()
)
marketplace = AzureVmMarketplaceSchema.schema().load( # type: ignore
self.marketplace_raw
marketplace = schema.load_by_type(
AzureVmMarketplaceSchema, self.marketplace_raw
)
# this step makes marketplace_raw is validated, and
# filter out any unwanted content.
@ -193,11 +193,11 @@ class AzureNodeSchema:
self.shared_gallery_raw = dict(
(k, v.lower()) for k, v in self.shared_gallery_raw.items()
)
shared_gallery = SharedImageGallerySchema.schema().load( # type: ignore
self.shared_gallery_raw
shared_gallery = schema.load_by_type(
SharedImageGallerySchema, self.shared_gallery_raw
)
if not shared_gallery.subscription_id: # type: ignore
shared_gallery.subscription_id = self.subscription_id # type: ignore
if not shared_gallery.subscription_id:
shared_gallery.subscription_id = self.subscription_id
# this step makes shared_gallery_raw is validated, and
# filter out any unwanted content.
self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore

Просмотреть файл

@ -723,9 +723,7 @@ class AzurePlatform(Platform):
try:
with open(cached_file_name, "r") as f:
loaded_data: Dict[str, Any] = json.load(f)
loaded_obj = AzureLocation.schema().load( # type:ignore
loaded_data
)
loaded_obj = schema.load_by_type(AzureLocation, loaded_data)
except Exception as identifier:
# if schema changed, There may be exception, remove cache and retry
# Note: retry on this method depends on decorator

Просмотреть файл

@ -142,8 +142,8 @@ def _run_transformers(
transformers_runbook: List[schema.Transformer] = []
for runbook_data in transformers_data:
# get base transformer runbook for replacing variables.
runbook: schema.Transformer = schema.Transformer.schema().load( # type: ignore
runbook_data
runbook: schema.Transformer = schema.load_by_type(
schema.Transformer, runbook_data
)
if runbook.enabled:
@ -165,7 +165,7 @@ def _run_transformers(
replace_variables(runbook_data, copied_variables)
# revert to runbook
runbook = schema.Transformer.schema().load(runbook_data) # type: ignore
runbook = schema.load_by_type(schema.Transformer, runbook_data)
derived_builder = runbook_builder.derive(copied_variables)
transformer = factory.create_by_runbook(

Просмотреть файл

@ -20,7 +20,7 @@ class BaseClassWithRunbookMixin(BaseClassMixin):
) -> "BaseClassWithRunbookMixin":
if cls.type_schema() != type(runbook):
# reload if type is defined in subclass
runbook = cls.type_schema().schema().load(runbook.to_dict()) # type:ignore
runbook = schema.load_by_type(cls.type_schema(), runbook)
return cls(runbook=runbook, **kwargs)
@classmethod
@ -72,7 +72,7 @@ class Factory(InitializableMixin, Generic[T_BASECLASS], SubClassTypeDict):
raise LisaException(
f"cannot find subclass '{type_name}' of {self._base_type.__name__}"
)
instance = sub_type.schema().load(raw_runbook) # type: ignore
instance: Any = schema.load_by_type(sub_type, raw_runbook)
if hasattr(instance, "extended_schemas"):
if instance.extended_schemas:
raise LisaException(

Просмотреть файл

@ -4,7 +4,7 @@
import os
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Dict, List, Optional, Union
import yaml
@ -142,10 +142,8 @@ def _load_from_runbook(
current_variables = higher_level_variables.copy()
if constants.VARIABLE in runbook_data:
variable_entries: List[
schema.Variable
] = schema.Variable.schema().load( # type:ignore
runbook_data[constants.VARIABLE], many=True
variable_entries: List[schema.Variable] = schema.load_by_type_many(
schema.Variable, runbook_data[constants.VARIABLE]
)
left_variables = variable_entries.copy()
@ -262,10 +260,7 @@ def load_from_variable_entry(
value = raw_value
else:
if isinstance(raw_value, dict):
raw_value = cast(
schema.VariableEntry,
schema.VariableEntry.schema().load(raw_value), # type: ignore
)
raw_value = schema.load_by_type(schema.VariableEntry, raw_value)
is_secret = is_secret or raw_value.is_secret
mask_pattern_name = raw_value.mask
value = raw_value.value

Просмотреть файл

@ -116,40 +116,46 @@ class RequirementTestCase(SearchSpaceTestCase):
def test_supported_simple_requirement(self) -> None:
n1 = schema.NodeSpace()
n1 = n1.generate_min_capability(n1)
n4 = schema.NodeSpace.schema().load( # type:ignore
{"type": constants.ENVIRONMENTS_NODES_REQUIREMENT, "core_count": 4}
n4 = schema.load_by_type(
schema.NodeSpace,
{"type": constants.ENVIRONMENTS_NODES_REQUIREMENT, "core_count": 4},
)
n4 = n4.generate_min_capability(n4)
n4g1 = schema.NodeSpace.schema().load( # type:ignore
n4g1 = schema.load_by_type(
schema.NodeSpace,
{
"type": constants.ENVIRONMENTS_NODES_REQUIREMENT,
"core_count": 4,
"gpu_count": 1,
}
},
)
n4g1 = n4g1.generate_min_capability(n4g1)
n6 = schema.NodeSpace.schema().load( # type:ignore
{"type": constants.ENVIRONMENTS_NODES_REQUIREMENT, "core_count": 6}
n6 = schema.load_by_type(
schema.NodeSpace,
{"type": constants.ENVIRONMENTS_NODES_REQUIREMENT, "core_count": 6},
)
n6 = n6.generate_min_capability(n6)
n6g2 = schema.NodeSpace.schema().load( # type:ignore
n6g2 = schema.load_by_type(
schema.NodeSpace,
{
"type": constants.ENVIRONMENTS_NODES_REQUIREMENT,
"core_count": 6,
"gpu_count": 2,
}
},
)
n6g2 = n6g2.generate_min_capability(n6g2)
n6g1 = schema.NodeSpace.schema().load( # type:ignore
n6g1 = schema.load_by_type(
schema.NodeSpace,
{
"type": constants.ENVIRONMENTS_NODES_REQUIREMENT,
"core_count": 6,
"gpu_count": 1,
}
},
)
n6g1 = n6g1.generate_min_capability(n6g1)
n10 = schema.NodeSpace.schema().load( # type:ignore
{"type": constants.ENVIRONMENTS_NODES_REQUIREMENT, "core_count": 10}
n10 = schema.load_by_type(
schema.NodeSpace,
{"type": constants.ENVIRONMENTS_NODES_REQUIREMENT, "core_count": 10},
)
n10 = n10.generate_min_capability(n10)

Просмотреть файл

@ -173,7 +173,7 @@ def generate_runbook(
environments.append({"nodes": [n]})
data = {"max_concurrency": 2, constants.ENVIRONMENTS: environments}
return schema.EnvironmentRoot.schema().load(data) # type: ignore
return schema.load_by_type(schema.EnvironmentRoot, data)
class EnvironmentTestCase(TestCase):

Просмотреть файл

@ -120,7 +120,7 @@ def generate_platform(
"admin_password": admin_password,
"admin_private_key_file": admin_key_file,
}
runbook = schema.Platform.schema().load(runbook_data) # type: ignore
runbook = schema.load_by_type(schema.Platform, runbook_data)
platform = load_platform([runbook])
platform.initialize()
try:

Просмотреть файл

@ -137,9 +137,7 @@ class TestTransformerCase(TestCase):
def test_transformer_no_name(self) -> None:
# no name, the type will be used. in name
transformers_data: List[Any] = [{"type": MOCK, "items": {"v0": "v0_1"}}]
transformers = schema.Transformer.schema().load( # type: ignore
transformers_data, many=True
)
transformers = schema.load_by_type_many(schema.Transformer, transformers_data)
runbook_builder = self._generate_runbook_builder(transformers)
result = transformer._run_transformers(runbook_builder)
@ -199,10 +197,8 @@ class TestTransformerCase(TestCase):
items: Dict[str, str] = dict()
for item_index in range(index + 1):
items[f"v{item_index}"] = f"{index}_{item_index}"
runbook: schema.Transformer = (
schema.Transformer.schema().load( # type:ignore
{"type": MOCK, "name": f"t{index}", "items": items}
)
runbook: schema.Transformer = schema.load_by_type(
schema.Transformer, {"type": MOCK, "name": f"t{index}", "items": items}
)
results.append(runbook)