diff --git a/airflow/contrib/operators/ssh_operator.py b/airflow/contrib/operators/ssh_operator.py index 2e890f463e..747ad04ff0 100644 --- a/airflow/contrib/operators/ssh_operator.py +++ b/airflow/contrib/operators/ssh_operator.py @@ -69,16 +69,17 @@ class SSHOperator(BaseOperator): def execute(self, context): try: if self.ssh_conn_id and not self.ssh_hook: - self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) + self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id, + timeout=self.timeout) if not self.ssh_hook: - raise AirflowException("can not operate without ssh_hook or ssh_conn_id") + raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.") if self.remote_host is not None: self.ssh_hook.remote_host = self.remote_host if not self.command: - raise AirflowException("no command specified so nothing to execute here.") + raise AirflowException("SSH command not specified. Aborting.") with self.ssh_hook.get_conn() as ssh_client: # Auto apply tty when its required in case of sudo diff --git a/tests/contrib/operators/test_ssh_operator.py b/tests/contrib/operators/test_ssh_operator.py index b97ba84a01..7ddd24b2ac 100644 --- a/tests/contrib/operators/test_ssh_operator.py +++ b/tests/contrib/operators/test_ssh_operator.py @@ -7,9 +7,9 @@ # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -58,6 +58,23 @@ class SSHOperatorTest(unittest.TestCase): self.hook = hook self.dag = dag + def test_hook_created_correctly(self): + TIMEOUT = 20 + SSH_ID = "ssh_default" + task = SSHOperator( + task_id="test", + command="echo -n airflow", + dag=self.dag, + timeout=TIMEOUT, + ssh_conn_id="ssh_default" + ) + self.assertIsNotNone(task) + + task.execute(None) + + self.assertEquals(TIMEOUT, task.ssh_hook.timeout) + self.assertEquals(SSH_ID, task.ssh_hook.ssh_conn_id) + def test_json_command_execution(self): configuration.conf.set("core", "enable_xcom_pickling", "False") task = SSHOperator(