limit eol of files to LF, and fix platform bug

This commit is contained in:
Chi Song 2020-08-26 13:09:47 +08:00
Родитель 145877ff63
Коммит 9ffefdc866
3 изменённых файлов: 62 добавлений и 41 удалений

1
.gitattributes поставляемый Normal file
Просмотреть файл

@ -0,0 +1 @@
* text=auto

Просмотреть файл

@ -61,13 +61,15 @@ def _set_schema_class(
def _load_platform_schema() -> None:
# load platform extensions
platform_fields: List[Tuple[str, Any, Any]] = []
node_fields: List[Tuple[str, Any, Any]] = []
platform_field_names: List[str] = []
# 1. discover extension schemas and construct new field
for platform in platforms.values():
platform_type_name = platform.platform_type()
platform_schema = platform.platform_schema
if platform_schema:
platform_type_name = platform.platform_type()
platform_field = (
platform_type_name,
Optional[platform_schema],
@ -75,44 +77,54 @@ def _load_platform_schema() -> None:
)
platform_fields.append(platform_field)
platform_field_names.append(platform_type_name)
node_schema = platform.node_schema
if node_schema:
node_field = (
platform_type_name,
Optional[node_schema],
field(default=None),
)
node_fields.append(node_field)
# 2. refresh data class in schema platform and environment
if platform_fields:
# add in platform type
platform_with_type_fields = platform_fields.copy()
platform_field_names.append(constants.PLATFORM_READY)
type_field = (
constants.TYPE,
str,
field(
default=constants.PLATFORM_READY,
metadata=schema.metadata(
required=True,
validate=marshmallow_validate.OneOf(platform_field_names),
if platform_fields or node_fields:
if platform_fields:
# add in platform type
platform_with_type_fields = platform_fields.copy()
platform_field_names.append(constants.PLATFORM_READY)
type_field = (
constants.TYPE,
str,
field(
default=constants.PLATFORM_READY,
metadata=schema.metadata(
required=True,
validate=marshmallow_validate.OneOf(platform_field_names),
),
),
),
)
platform_with_type_fields.append(type_field)
# refresh platform
_set_schema_class(schema.Platform, platform_with_type_fields)
schema.Platform.supported_types = platform_field_names
)
platform_with_type_fields.append(type_field)
# refresh platform
_set_schema_class(schema.Platform, platform_with_type_fields)
schema.Platform.supported_types = platform_field_names
# refresh node spec, template, and chain dataclasses
_set_schema_class(schema.NodeSpec, platform_fields)
_set_schema_class(schema.Template, platform_fields)
if node_fields:
# refresh node spec, template, and chain dataclasses
_set_schema_class(schema.NodeSpec, node_fields)
_set_schema_class(schema.Template, node_fields)
template_in_config = (
constants.ENVIRONMENTS_TEMPLATE,
Optional[schema.Template],
field(default=None),
)
_set_schema_class(schema.Environment, [template_in_config])
platform_spec_in_config = (
constants.ENVIRONMENTS,
Optional[List[schema.Environment]],
field(default=None),
)
_set_schema_class(schema.EnvironmentRoot, [platform_spec_in_config])
template_in_config = (
constants.ENVIRONMENTS_TEMPLATE,
Optional[schema.Template],
field(default=None),
)
_set_schema_class(schema.Environment, [template_in_config])
platform_spec_in_config = (
constants.ENVIRONMENTS,
Optional[List[schema.Environment]],
field(default=None),
)
_set_schema_class(schema.EnvironmentRoot, [platform_spec_in_config])
platform_in_config = (
constants.PLATFORM,

Просмотреть файл

@ -39,19 +39,27 @@ T = TypeVar("T", bound=DataClassJsonMixin)
class ExtendableSchemaMixin:
def get_extended_schema(
self, schema_type: Type[T], schema_name: str = constants.TYPE
) -> T:
def get_extended_schema(self, schema_type: Type[T], field_name: str = "") -> T:
"""
schema_type: type of schema
field_name: the field name which stores the data, if it's "", get it from type
"""
assert issubclass(
schema_type, DataClassJsonMixin
), "schema_type must annotate from DataClassJsonMixin"
assert hasattr(self, schema_name), f"cannot find attr '{schema_name}'"
if not field_name:
assert hasattr(self, constants.TYPE), (
f"cannot find type attr on '{schema_type.__name__}'."
f"either set field_name or make sure type attr exists."
)
field_name = getattr(self, constants.TYPE)
assert hasattr(self, field_name), f"cannot find attr '{field_name}'"
customized_config = getattr(self, schema_name)
customized_config = getattr(self, field_name)
if not isinstance(customized_config, schema_type):
raise LisaException(
f"schema type mismatch, expected type: {schema_type}"
f"data: {customized_config}"
f"schema type mismatch, expected type: {schema_type} "
f"data type: {type(customized_config)}"
)
return customized_config