Sort lists, sets and tuples in Serialized DAGs (#14909)
Currently we check whether the DAG has changed via dag_hash. The problem is that, since the insertion order is not guaranteed, it can produce a different hash and hence result in an unnecessary DB write. This commit fixes it.
This commit is contained in:
Родитель
70f184add5
Коммит
4531168e90
|
@ -214,7 +214,7 @@ class BaseSerialization:
|
||||||
elif isinstance(var, dict):
|
elif isinstance(var, dict):
|
||||||
return cls._encode({str(k): cls._serialize(v) for k, v in var.items()}, type_=DAT.DICT)
|
return cls._encode({str(k): cls._serialize(v) for k, v in var.items()}, type_=DAT.DICT)
|
||||||
elif isinstance(var, list):
|
elif isinstance(var, list):
|
||||||
return [cls._serialize(v) for v in var]
|
return sorted(cls._serialize(v) for v in var)
|
||||||
elif HAS_KUBERNETES and isinstance(var, k8s.V1Pod):
|
elif HAS_KUBERNETES and isinstance(var, k8s.V1Pod):
|
||||||
json_pod = PodGenerator.serialize_pod(var)
|
json_pod = PodGenerator.serialize_pod(var)
|
||||||
return cls._encode(json_pod, type_=DAT.POD)
|
return cls._encode(json_pod, type_=DAT.POD)
|
||||||
|
@ -240,10 +240,10 @@ class BaseSerialization:
|
||||||
return str(get_python_source(var))
|
return str(get_python_source(var))
|
||||||
elif isinstance(var, set):
|
elif isinstance(var, set):
|
||||||
# FIXME: casts set to list in customized serialization in future.
|
# FIXME: casts set to list in customized serialization in future.
|
||||||
return cls._encode([cls._serialize(v) for v in var], type_=DAT.SET)
|
return cls._encode(sorted(cls._serialize(v) for v in var), type_=DAT.SET)
|
||||||
elif isinstance(var, tuple):
|
elif isinstance(var, tuple):
|
||||||
# FIXME: casts tuple to list in customized serialization in future.
|
# FIXME: casts tuple to list in customized serialization in future.
|
||||||
return cls._encode([cls._serialize(v) for v in var], type_=DAT.TUPLE)
|
return cls._encode(sorted(cls._serialize(v) for v in var), type_=DAT.TUPLE)
|
||||||
elif isinstance(var, TaskGroup):
|
elif isinstance(var, TaskGroup):
|
||||||
return SerializedTaskGroup.serialize_task_group(var)
|
return SerializedTaskGroup.serialize_task_group(var)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -353,9 +353,10 @@ class TestStringifiedDAGs(unittest.TestCase):
|
||||||
"_task_group",
|
"_task_group",
|
||||||
}
|
}
|
||||||
for field in fields_to_check:
|
for field in fields_to_check:
|
||||||
assert getattr(serialized_dag, field) == getattr(
|
dag_field = getattr(dag, field)
|
||||||
dag, field
|
if isinstance(dag_field, list):
|
||||||
), f'{dag.dag_id}.{field} does not match'
|
dag_field = sorted(dag_field)
|
||||||
|
assert getattr(serialized_dag, field) == dag_field, f'{dag.dag_id}.{field} does not match'
|
||||||
|
|
||||||
if dag.default_args:
|
if dag.default_args:
|
||||||
for k, v in dag.default_args.items():
|
for k, v in dag.default_args.items():
|
||||||
|
@ -1027,6 +1028,33 @@ class TestStringifiedDAGs(unittest.TestCase):
|
||||||
|
|
||||||
assert deserialized_dag.has_on_failure_callback is expected_value
|
assert deserialized_dag.has_on_failure_callback is expected_value
|
||||||
|
|
||||||
|
@parameterized.expand(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
['task_1', 'task_5', 'task_2', 'task_4'],
|
||||||
|
['task_1', 'task_2', 'task_4', 'task_5'],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{'task_1', 'task_5', 'task_2', 'task_4'},
|
||||||
|
['task_1', 'task_2', 'task_4', 'task_5'],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
('task_1', 'task_5', 'task_2', 'task_4'),
|
||||||
|
['task_1', 'task_2', 'task_4', 'task_5'],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"task3": "test3", "task2": "test2", "task1": "test1"},
|
||||||
|
{"task1": "test1", "task2": "test2", "task3": "test3"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def test_serialized_objects_are_sorted(self, object_to_serialized, expected_output):
|
||||||
|
"""Test Serialized Lists, Sets and Tuples are sorted"""
|
||||||
|
serialized_obj = SerializedDAG._serialize(object_to_serialized)
|
||||||
|
if isinstance(serialized_obj, dict) and "__type" in serialized_obj:
|
||||||
|
serialized_obj = serialized_obj["__var"]
|
||||||
|
assert serialized_obj == expected_output
|
||||||
|
|
||||||
|
|
||||||
def test_kubernetes_optional():
|
def test_kubernetes_optional():
|
||||||
"""Serialisation / deserialisation continues to work without kubernetes installed"""
|
"""Serialisation / deserialisation continues to work without kubernetes installed"""
|
||||||
|
|
Загрузка…
Ссылка в новой задаче