diff --git a/airflow/providers/jira/hooks/jira.py b/airflow/providers/jira/hooks/jira.py index b36c65cac6..3afc9ae9dc 100644 --- a/airflow/providers/jira/hooks/jira.py +++ b/airflow/providers/jira/hooks/jira.py @@ -16,6 +16,8 @@ # specific language governing permissions and limitations # under the License. """Hook for JIRA""" +from typing import Any, Optional + from jira import JIRA from jira.exceptions import JIRAError @@ -31,15 +33,15 @@ class JiraHook(BaseHook): :type jira_conn_id: str """ def __init__(self, - jira_conn_id='jira_default', - proxies=None): + jira_conn_id: str = 'jira_default', + proxies: Optional[Any] = None) -> None: super().__init__() self.jira_conn_id = jira_conn_id self.proxies = proxies self.client = None self.get_conn() - def get_conn(self): + def get_conn(self) -> JIRA: if not self.client: self.log.debug('Creating Jira client for conn_id: %s', self.jira_conn_id) diff --git a/airflow/providers/jira/operators/jira.py b/airflow/providers/jira/operators/jira.py index 07775af9e4..352fe1a219 100644 --- a/airflow/providers/jira/operators/jira.py +++ b/airflow/providers/jira/operators/jira.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. +from typing import Any, Callable, Dict, Optional from airflow.exceptions import AirflowException from airflow.models import BaseOperator @@ -45,13 +46,13 @@ class JiraOperator(BaseOperator): @apply_defaults def __init__(self, - jira_conn_id='jira_default', - jira_method=None, - jira_method_args=None, - result_processor=None, - get_jira_resource_method=None, + jira_method: str, + jira_conn_id: str = 'jira_default', + jira_method_args: Optional[dict] = None, + result_processor: Optional[Callable] = None, + get_jira_resource_method: Optional[Callable] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.jira_conn_id = jira_conn_id self.method_name = jira_method @@ -59,7 +60,7 @@ class JiraOperator(BaseOperator): self.result_processor = result_processor self.get_jira_resource_method = get_jira_resource_method - def execute(self, context): + def execute(self, context: Dict) -> Any: try: if self.get_jira_resource_method is not None: # if get_jira_resource_method is provided, jira_method will be executed on diff --git a/airflow/providers/jira/sensors/jira.py b/airflow/providers/jira/sensors/jira.py index e136e5be34..00d673b6c8 100644 --- a/airflow/providers/jira/sensors/jira.py +++ b/airflow/providers/jira/sensors/jira.py @@ -15,7 +15,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from jira.resources import Resource +from typing import Any, Callable, Dict, Optional + +from jira.resources import Issue, Resource from airflow.providers.jira.operators.jira import JIRAError, JiraOperator from airflow.sensors.base_sensor_operator import BaseSensorOperator @@ -38,12 +40,12 @@ class JiraSensor(BaseSensorOperator): @apply_defaults def __init__(self, - jira_conn_id='jira_default', - method_name=None, - method_params=None, - result_processor=None, + method_name: str, + jira_conn_id: str = 'jira_default', + method_params: Optional[dict] = None, + result_processor: Optional[Callable] = None, *args, - **kwargs): + **kwargs) -> None: super().__init__(*args, **kwargs) self.jira_conn_id = jira_conn_id self.result_processor = None @@ -57,7 +59,7 @@ class JiraSensor(BaseSensorOperator): jira_method_args=self.method_params, result_processor=self.result_processor) - def poke(self, context): + def poke(self, context: Dict) -> Any: return self.jira_operator.execute(context=context) @@ -81,13 +83,12 @@ class JiraTicketSensor(JiraSensor): @apply_defaults def __init__(self, - jira_conn_id='jira_default', - ticket_id=None, - field=None, - expected_value=None, - field_checker_func=None, - *args, - **kwargs): + jira_conn_id: str = 'jira_default', + ticket_id: Optional[str] = None, + field: Optional[str] = None, + expected_value: Optional[str] = None, + field_checker_func: Optional[Callable] = None, + **kwargs) -> None: self.jira_conn_id = jira_conn_id self.ticket_id = ticket_id @@ -98,10 +99,9 @@ class JiraTicketSensor(JiraSensor): super().__init__(jira_conn_id=jira_conn_id, result_processor=field_checker_func, - *args, **kwargs) - def poke(self, context): + def poke(self, context: Dict) -> Any: self.log.info('Jira Sensor checking for change in ticket: %s', self.ticket_id) self.jira_operator.method_name = "issue" @@ -111,7 +111,7 @@ class JiraTicketSensor(JiraSensor): } return JiraSensor.poke(self, context=context) - def issue_field_checker(self, issue): + def issue_field_checker(self, issue: Issue) -> Optional[bool]: """Check issue using different conditions to prepare to evaluate sensor.""" result = None try: # pylint: disable=too-many-nested-blocks diff --git a/tests/providers/jira/sensors/test_jira.py b/tests/providers/jira/sensors/test_jira.py index 59a3c0c630..782eabbd87 100644 --- a/tests/providers/jira/sensors/test_jira.py +++ b/tests/providers/jira/sensors/test_jira.py @@ -64,6 +64,7 @@ class TestJiraSensor(unittest.TestCase): jira_mock.return_value.issue.return_value = minimal_test_ticket ticket_label_sensor = JiraTicketSensor( + method_name='issue', task_id='search-ticket-test', ticket_id='TEST-1226', field_checker_func=TestJiraSensor.field_checker_func,