diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index f2724f3c63..28781aa41a 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -81,7 +81,7 @@ class AWSAthenaOperator(BaseOperator): def get_hook(self): """Create and return an AWSAthenaHook.""" - return AWSAthenaHook(self.aws_conn_id, self.sleep_time) + return AWSAthenaHook(self.aws_conn_id, sleep_time=self.sleep_time) def execute(self, context): """ diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index 21672a811f..109c983462 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -62,7 +62,7 @@ class TestAWSAthenaOperator(unittest.TestCase): self.athena = AWSAthenaOperator(task_id='test_aws_athena_operator', query='SELECT * FROM TEST_TABLE', database='TEST_DATABASE', output_location='s3://test_s3_bucket/', client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595', - sleep_time=1, max_tries=3, dag=self.dag) + sleep_time=0, max_tries=3, dag=self.dag) def test_init(self): self.assertEqual(self.athena.task_id, MOCK_DATA['task_id']) @@ -70,7 +70,10 @@ class TestAWSAthenaOperator(unittest.TestCase): self.assertEqual(self.athena.database, MOCK_DATA['database']) self.assertEqual(self.athena.aws_conn_id, 'aws_default') self.assertEqual(self.athena.client_request_token, MOCK_DATA['client_request_token']) - self.assertEqual(self.athena.sleep_time, 1) + self.assertEqual(self.athena.sleep_time, 0) + + hook = self.athena.get_hook() + self.assertEqual(hook.sleep_time, 0) @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("SUCCESS",)) @mock.patch.object(AWSAthenaHook, 'run_query', return_value=ATHENA_QUERY_ID)