improve type hinting for celery provider (#9762)
This commit is contained in:
Родитель
a6b04d7b9a
Коммит
5bb228d841
|
@ -16,6 +16,8 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from celery.app import control
|
||||
|
||||
from airflow.sensors.base_sensor_operator import BaseSensorOperator
|
||||
|
@ -36,16 +38,16 @@ class CeleryQueueSensor(BaseSensorOperator):
|
|||
@apply_defaults
|
||||
def __init__(
|
||||
self,
|
||||
celery_queue,
|
||||
target_task_id=None,
|
||||
celery_queue: str,
|
||||
target_task_id: Optional[str] = None,
|
||||
*args,
|
||||
**kwargs):
|
||||
**kwargs) -> None:
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.celery_queue = celery_queue
|
||||
self.target_task_id = target_task_id
|
||||
|
||||
def _check_task_id(self, context):
|
||||
def _check_task_id(self, context: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Gets the returned Celery result from the Airflow task
|
||||
ID provided to the sensor, and returns True if the
|
||||
|
@ -60,7 +62,7 @@ class CeleryQueueSensor(BaseSensorOperator):
|
|||
celery_result = ti.xcom_pull(task_ids=self.target_task_id)
|
||||
return celery_result.ready()
|
||||
|
||||
def poke(self, context):
|
||||
def poke(self, context: Dict[str, Any]) -> bool:
|
||||
|
||||
if self.target_task_id:
|
||||
return self._check_task_id(context)
|
||||
|
|
Загрузка…
Ссылка в новой задаче