diff --git a/lisa/commands.py b/lisa/commands.py index 50812bcd1..fd7f45f80 100644 --- a/lisa/commands.py +++ b/lisa/commands.py @@ -24,9 +24,7 @@ def run(args: Namespace) -> int: notifier_data = builder.partial_resolve(constants.NOTIFIER) if notifier_data: - notifier_runbook = schema.Notifier.schema().load( # type: ignore - notifier_data, many=True - ) + notifier_runbook = schema.load_by_type_many(schema.Notifier, notifier_data) notifier.initialize(runbooks=notifier_runbook) run_message = notifier.TestRunMessage( test_project=builder.partial_resolve(constants.TEST_PROJECT), diff --git a/lisa/parameter_parser/runbook.py b/lisa/parameter_parser/runbook.py index f7ddc92f5..55b2b11b5 100644 --- a/lisa/parameter_parser/runbook.py +++ b/lisa/parameter_parser/runbook.py @@ -226,7 +226,7 @@ class RunbookBuilder: and data_from_include[constants.VARIABLE] ): variables_from_include = [ - schema.Variable.schema().load(variable) # type: ignore + schema.load_by_type(schema.Variable, variable) for variable in data_from_include[constants.VARIABLE] ] # resolve to absolute path @@ -238,7 +238,7 @@ class RunbookBuilder: and data_from_current[constants.VARIABLE] ): variables_from_current: List[schema.Variable] = [ - schema.Variable.schema().load(variable) # type: ignore + schema.load_by_type(schema.Variable, variable) for variable in data_from_current[constants.VARIABLE] ] @@ -345,7 +345,7 @@ class RunbookBuilder: for include_raw in includes: try: include: schema.Include - include = schema.Include.schema().load(include_raw) # type: ignore + include = schema.load_by_type(schema.Include, include_raw) except Exception as identifer: raise LisaException( f"error on loading include node [{include_raw}]: {identifer}" diff --git a/lisa/platform_.py b/lisa/platform_.py index de680237c..d84b5e075 100644 --- a/lisa/platform_.py +++ b/lisa/platform_.py @@ -195,8 +195,6 @@ def load_platform(platforms_runbook: List[schema.Platform]) -> Platform: def load_platform_from_builder(runbook_builder: RunbookBuilder) -> Platform: platform_runbook_data = runbook_builder.partial_resolve(constants.PLATFORM) - platform_runbook = schema.Platform.schema().load( # type: ignore - platform_runbook_data, many=True - ) + platform_runbook = schema.load_by_type_many(schema.Platform, platform_runbook_data) platform = load_platform(platform_runbook) return platform diff --git a/lisa/runner.py b/lisa/runner.py index 2b7563a3a..af5b0e0fb 100644 --- a/lisa/runner.py +++ b/lisa/runner.py @@ -210,7 +210,7 @@ class RootRunner(Action): runner_filters: Dict[str, List[schema.BaseTestCaseFilter]] = {} for raw_filter in runbook.testcase_raw: # by default run all filtered cases unless 'enable' is specified as false - filter = schema.BaseTestCaseFilter.schema().load(raw_filter) # type:ignore + filter = schema.load_by_type(schema.BaseTestCaseFilter, raw_filter) if filter.enable: raw_filters: List[schema.BaseTestCaseFilter] = runner_filters.get( filter.type, [] diff --git a/lisa/runners/lisa_runner.py b/lisa/runners/lisa_runner.py index 8e1f35090..e2e12c1d1 100644 --- a/lisa/runners/lisa_runner.py +++ b/lisa/runners/lisa_runner.py @@ -608,10 +608,8 @@ class LisaRunner(BaseRunner): node_requirement_data = deep_update_dict( platform_requirement_data, node_requirement_data ) - node_requirement = ( - schema.NodeSpace.schema().load( # type: ignore - node_requirement_data - ) + node_requirement = schema.load_by_type( + schema.NodeSpace, node_requirement_data ) environment_requirement.nodes[index] = node_requirement diff --git a/lisa/schema.py b/lisa/schema.py index 8d83a8dbe..f9879e661 100644 --- a/lisa/schema.py +++ b/lisa/schema.py @@ -120,8 +120,8 @@ class ExtendableSchemaMixin: runbook_type=runbook_type, type_name=type_name ) if self.extended_schemas and type_name in self.extended_schemas: - self._extended_runbook: T = runbook_type.schema().load( # type:ignore - self.extended_schemas[type_name] + self._extended_runbook: T = load_by_type( + runbook_type, self.extended_schemas[type_name] ) else: # value may be filled outside, so hold and return an object. @@ -266,7 +266,7 @@ class Extension: if isinstance(extension, str): extension = Extension(path=extension) elif isinstance(extension, dict): - extension = Extension.schema().load(extension) # type: ignore + extension = load_by_type(Extension, extension) results.append(extension) return results @@ -324,10 +324,7 @@ class Variable: if isinstance(self.value_raw, dict): self.value: Union[ str, bool, int, VariableEntry, List[Union[str, bool, int]] - ] = cast( - VariableEntry, - VariableEntry.schema().load(self.value_raw), # type:ignore - ) + ] = load_by_type(VariableEntry, self.value_raw) else: self.value = self.value_raw @@ -676,18 +673,14 @@ class Environment: for node_raw in self.nodes_raw: node_type = node_raw[constants.TYPE] if node_type == constants.ENVIRONMENTS_NODES_REQUIREMENT: - original_req: NodeSpace = NodeSpace.schema().load( # type:ignore - node_raw - ) + original_req: NodeSpace = load_by_type(NodeSpace, node_raw) expanded_req = original_req.expand_by_node_count() if self.nodes_requirement is None: self.nodes_requirement = [] self.nodes_requirement.extend(expanded_req) else: # load base schema for future parsing - node: Node = Node.schema().load( # type:ignore - node_raw - ) + node: Node = load_by_type(Node, node_raw) results.append(node) self.nodes_raw = None @@ -764,7 +757,7 @@ class Platform(TypedSchema, ExtendableSchemaMixin): # But the schema will be validated here. The original NodeSpace object holds if self.requirement: # validate schema of raw inputs - Capability.schema().load(self.requirement) # type: ignore + load_by_type(Capability, self.requirement) @dataclass_json() @@ -916,3 +909,26 @@ class Runbook: } ] self.testcase: List[Any] = [] + + +def load_by_type(schema_type: Type[T], raw_runbook: Any, many: bool = False) -> T: + """ + Convert dict, list or base typed schema to specified typed schema. + """ + if type(raw_runbook) == schema_type: + return raw_runbook + + if not isinstance(raw_runbook, dict) and not many: + raw_runbook = raw_runbook.to_dict() + + result: T = schema_type.schema().load(raw_runbook, many=many) # type: ignore + return result + + +def load_by_type_many(schema_type: Type[T], raw_runbook: Any) -> List[T]: + """ + Convert raw list to list of typed schema. It has different returned type + with load_by_type. + """ + result = load_by_type(schema_type, raw_runbook=raw_runbook, many=True) + return cast(List[T], result) diff --git a/lisa/sut_orchestrator/azure/common.py b/lisa/sut_orchestrator/azure/common.py index 365eb86e6..4fc568257 100644 --- a/lisa/sut_orchestrator/azure/common.py +++ b/lisa/sut_orchestrator/azure/common.py @@ -132,8 +132,8 @@ class AzureNodeSchema: self.marketplace_raw = dict( (k, v.lower()) for k, v in self.marketplace_raw.items() ) - marketplace = AzureVmMarketplaceSchema.schema().load( # type: ignore - self.marketplace_raw + marketplace = schema.load_by_type( + AzureVmMarketplaceSchema, self.marketplace_raw ) # this step makes marketplace_raw is validated, and # filter out any unwanted content. @@ -193,11 +193,11 @@ class AzureNodeSchema: self.shared_gallery_raw = dict( (k, v.lower()) for k, v in self.shared_gallery_raw.items() ) - shared_gallery = SharedImageGallerySchema.schema().load( # type: ignore - self.shared_gallery_raw + shared_gallery = schema.load_by_type( + SharedImageGallerySchema, self.shared_gallery_raw ) - if not shared_gallery.subscription_id: # type: ignore - shared_gallery.subscription_id = self.subscription_id # type: ignore + if not shared_gallery.subscription_id: + shared_gallery.subscription_id = self.subscription_id # this step makes shared_gallery_raw is validated, and # filter out any unwanted content. self.shared_gallery_raw = shared_gallery.to_dict() # type: ignore diff --git a/lisa/sut_orchestrator/azure/platform_.py b/lisa/sut_orchestrator/azure/platform_.py index 61cf2a690..0a62fd671 100644 --- a/lisa/sut_orchestrator/azure/platform_.py +++ b/lisa/sut_orchestrator/azure/platform_.py @@ -723,9 +723,7 @@ class AzurePlatform(Platform): try: with open(cached_file_name, "r") as f: loaded_data: Dict[str, Any] = json.load(f) - loaded_obj = AzureLocation.schema().load( # type:ignore - loaded_data - ) + loaded_obj = schema.load_by_type(AzureLocation, loaded_data) except Exception as identifier: # if schema changed, There may be exception, remove cache and retry # Note: retry on this method depends on decorator diff --git a/lisa/transformer.py b/lisa/transformer.py index 594b7dfcc..6db8f7382 100644 --- a/lisa/transformer.py +++ b/lisa/transformer.py @@ -142,8 +142,8 @@ def _run_transformers( transformers_runbook: List[schema.Transformer] = [] for runbook_data in transformers_data: # get base transformer runbook for replacing variables. - runbook: schema.Transformer = schema.Transformer.schema().load( # type: ignore - runbook_data + runbook: schema.Transformer = schema.load_by_type( + schema.Transformer, runbook_data ) if runbook.enabled: @@ -165,7 +165,7 @@ def _run_transformers( replace_variables(runbook_data, copied_variables) # revert to runbook - runbook = schema.Transformer.schema().load(runbook_data) # type: ignore + runbook = schema.load_by_type(schema.Transformer, runbook_data) derived_builder = runbook_builder.derive(copied_variables) transformer = factory.create_by_runbook( diff --git a/lisa/util/subclasses.py b/lisa/util/subclasses.py index 2ebd3f872..d77199eea 100644 --- a/lisa/util/subclasses.py +++ b/lisa/util/subclasses.py @@ -20,7 +20,7 @@ class BaseClassWithRunbookMixin(BaseClassMixin): ) -> "BaseClassWithRunbookMixin": if cls.type_schema() != type(runbook): # reload if type is defined in subclass - runbook = cls.type_schema().schema().load(runbook.to_dict()) # type:ignore + runbook = schema.load_by_type(cls.type_schema(), runbook) return cls(runbook=runbook, **kwargs) @classmethod @@ -72,7 +72,7 @@ class Factory(InitializableMixin, Generic[T_BASECLASS], SubClassTypeDict): raise LisaException( f"cannot find subclass '{type_name}' of {self._base_type.__name__}" ) - instance = sub_type.schema().load(raw_runbook) # type: ignore + instance: Any = schema.load_by_type(sub_type, raw_runbook) if hasattr(instance, "extended_schemas"): if instance.extended_schemas: raise LisaException( diff --git a/lisa/variable.py b/lisa/variable.py index cf6a2face..7da05a00a 100644 --- a/lisa/variable.py +++ b/lisa/variable.py @@ -4,7 +4,7 @@ import os import re from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Dict, List, Optional, Union import yaml @@ -142,10 +142,8 @@ def _load_from_runbook( current_variables = higher_level_variables.copy() if constants.VARIABLE in runbook_data: - variable_entries: List[ - schema.Variable - ] = schema.Variable.schema().load( # type:ignore - runbook_data[constants.VARIABLE], many=True + variable_entries: List[schema.Variable] = schema.load_by_type_many( + schema.Variable, runbook_data[constants.VARIABLE] ) left_variables = variable_entries.copy() @@ -262,10 +260,7 @@ def load_from_variable_entry( value = raw_value else: if isinstance(raw_value, dict): - raw_value = cast( - schema.VariableEntry, - schema.VariableEntry.schema().load(raw_value), # type: ignore - ) + raw_value = schema.load_by_type(schema.VariableEntry, raw_value) is_secret = is_secret or raw_value.is_secret mask_pattern_name = raw_value.mask value = raw_value.value diff --git a/selftests/test_env_requirement.py b/selftests/test_env_requirement.py index 7e8da98ed..3a70a8316 100644 --- a/selftests/test_env_requirement.py +++ b/selftests/test_env_requirement.py @@ -116,40 +116,46 @@ class RequirementTestCase(SearchSpaceTestCase): def test_supported_simple_requirement(self) -> None: n1 = schema.NodeSpace() n1 = n1.generate_min_capability(n1) - n4 = schema.NodeSpace.schema().load( # type:ignore - {"type": constants.ENVIRONMENTS_NODES_REQUIREMENT, "core_count": 4} + n4 = schema.load_by_type( + schema.NodeSpace, + {"type": constants.ENVIRONMENTS_NODES_REQUIREMENT, "core_count": 4}, ) n4 = n4.generate_min_capability(n4) - n4g1 = schema.NodeSpace.schema().load( # type:ignore + n4g1 = schema.load_by_type( + schema.NodeSpace, { "type": constants.ENVIRONMENTS_NODES_REQUIREMENT, "core_count": 4, "gpu_count": 1, - } + }, ) n4g1 = n4g1.generate_min_capability(n4g1) - n6 = schema.NodeSpace.schema().load( # type:ignore - {"type": constants.ENVIRONMENTS_NODES_REQUIREMENT, "core_count": 6} + n6 = schema.load_by_type( + schema.NodeSpace, + {"type": constants.ENVIRONMENTS_NODES_REQUIREMENT, "core_count": 6}, ) n6 = n6.generate_min_capability(n6) - n6g2 = schema.NodeSpace.schema().load( # type:ignore + n6g2 = schema.load_by_type( + schema.NodeSpace, { "type": constants.ENVIRONMENTS_NODES_REQUIREMENT, "core_count": 6, "gpu_count": 2, - } + }, ) n6g2 = n6g2.generate_min_capability(n6g2) - n6g1 = schema.NodeSpace.schema().load( # type:ignore + n6g1 = schema.load_by_type( + schema.NodeSpace, { "type": constants.ENVIRONMENTS_NODES_REQUIREMENT, "core_count": 6, "gpu_count": 1, - } + }, ) n6g1 = n6g1.generate_min_capability(n6g1) - n10 = schema.NodeSpace.schema().load( # type:ignore - {"type": constants.ENVIRONMENTS_NODES_REQUIREMENT, "core_count": 10} + n10 = schema.load_by_type( + schema.NodeSpace, + {"type": constants.ENVIRONMENTS_NODES_REQUIREMENT, "core_count": 10}, ) n10 = n10.generate_min_capability(n10) diff --git a/selftests/test_environment.py b/selftests/test_environment.py index ac4a4e574..459aa185e 100644 --- a/selftests/test_environment.py +++ b/selftests/test_environment.py @@ -173,7 +173,7 @@ def generate_runbook( environments.append({"nodes": [n]}) data = {"max_concurrency": 2, constants.ENVIRONMENTS: environments} - return schema.EnvironmentRoot.schema().load(data) # type: ignore + return schema.load_by_type(schema.EnvironmentRoot, data) class EnvironmentTestCase(TestCase): diff --git a/selftests/test_platform.py b/selftests/test_platform.py index 83f523fe2..dd17d27e1 100644 --- a/selftests/test_platform.py +++ b/selftests/test_platform.py @@ -120,7 +120,7 @@ def generate_platform( "admin_password": admin_password, "admin_private_key_file": admin_key_file, } - runbook = schema.Platform.schema().load(runbook_data) # type: ignore + runbook = schema.load_by_type(schema.Platform, runbook_data) platform = load_platform([runbook]) platform.initialize() try: diff --git a/selftests/test_transformer.py b/selftests/test_transformer.py index d2ed6bac2..fac0dbb09 100644 --- a/selftests/test_transformer.py +++ b/selftests/test_transformer.py @@ -137,9 +137,7 @@ class TestTransformerCase(TestCase): def test_transformer_no_name(self) -> None: # no name, the type will be used. in name transformers_data: List[Any] = [{"type": MOCK, "items": {"v0": "v0_1"}}] - transformers = schema.Transformer.schema().load( # type: ignore - transformers_data, many=True - ) + transformers = schema.load_by_type_many(schema.Transformer, transformers_data) runbook_builder = self._generate_runbook_builder(transformers) result = transformer._run_transformers(runbook_builder) @@ -199,10 +197,8 @@ class TestTransformerCase(TestCase): items: Dict[str, str] = dict() for item_index in range(index + 1): items[f"v{item_index}"] = f"{index}_{item_index}" - runbook: schema.Transformer = ( - schema.Transformer.schema().load( # type:ignore - {"type": MOCK, "name": f"t{index}", "items": items} - ) + runbook: schema.Transformer = schema.load_by_type( + schema.Transformer, {"type": MOCK, "name": f"t{index}", "items": items} ) results.append(runbook)