diff --git a/airflow/contrib/hooks/sftp_hook.py b/airflow/contrib/hooks/sftp_hook.py
index 48abfb6174..9dea2ee298 100644
--- a/airflow/contrib/hooks/sftp_hook.py
+++ b/airflow/contrib/hooks/sftp_hook.py
@@ -16,274 +16,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""
+This module is deprecated. Please use `airflow.providers.sftp.hooks.sftp_hook`.
+"""
-import datetime
-import stat
-from typing import Dict, List, Optional, Tuple
+import warnings
-import pysftp
+# pylint: disable=unused-import
+from airflow.providers.sftp.hooks.sftp_hook import SFTPHook # noqa
-from airflow.contrib.hooks.ssh_hook import SSHHook
-
-
-class SFTPHook(SSHHook):
- """
- This hook is inherited from SSH hook. Please refer to SSH hook for the input
- arguments.
-
- Interact with SFTP. Aims to be interchangeable with FTPHook.
-
- :Pitfalls::
-
- - In contrast with FTPHook describe_directory only returns size, type and
- modify. It doesn't return unix.owner, unix.mode, perm, unix.group and
- unique.
- - retrieve_file and store_file only take a local full path and not a
- buffer.
- - If no mode is passed to create_directory it will be created with 777
- permissions.
-
- Errors that may occur throughout but should be handled downstream.
- """
-
- def __init__(self, ftp_conn_id: str = 'sftp_default', *args, **kwargs) -> None:
- kwargs['ssh_conn_id'] = ftp_conn_id
- super().__init__(*args, **kwargs)
-
- self.conn = None
- self.private_key_pass = None
-
- # Fail for unverified hosts, unless this is explicitly allowed
- self.no_host_key_check = False
-
- if self.ssh_conn_id is not None:
- conn = self.get_connection(self.ssh_conn_id)
- if conn.extra is not None:
- extra_options = conn.extra_dejson
- if 'private_key_pass' in extra_options:
- self.private_key_pass = extra_options.get('private_key_pass', None)
-
- # For backward compatibility
- # TODO: remove in Airflow 2.1
- import warnings
- if 'ignore_hostkey_verification' in extra_options:
- warnings.warn(
- 'Extra option `ignore_hostkey_verification` is deprecated.'
- 'Please use `no_host_key_check` instead.'
- 'This option will be removed in Airflow 2.1',
- DeprecationWarning,
- stacklevel=2,
- )
- self.no_host_key_check = str(
- extra_options['ignore_hostkey_verification']
- ).lower() == 'true'
-
- if 'no_host_key_check' in extra_options:
- self.no_host_key_check = str(
- extra_options['no_host_key_check']).lower() == 'true'
-
- if 'private_key' in extra_options:
- warnings.warn(
- 'Extra option `private_key` is deprecated.'
- 'Please use `key_file` instead.'
- 'This option will be removed in Airflow 2.1',
- DeprecationWarning,
- stacklevel=2,
- )
- self.key_file = extra_options.get('private_key')
-
- def get_conn(self) -> pysftp.Connection:
- """
- Returns an SFTP connection object
- """
- if self.conn is None:
- cnopts = pysftp.CnOpts()
- if self.no_host_key_check:
- cnopts.hostkeys = None
- cnopts.compression = self.compress
- conn_params = {
- 'host': self.remote_host,
- 'port': self.port,
- 'username': self.username,
- 'cnopts': cnopts
- }
- if self.password and self.password.strip():
- conn_params['password'] = self.password
- if self.key_file:
- conn_params['private_key'] = self.key_file
- if self.private_key_pass:
- conn_params['private_key_pass'] = self.private_key_pass
-
- self.conn = pysftp.Connection(**conn_params)
- return self.conn
-
- def close_conn(self) -> None:
- """
- Closes the connection. An error will occur if the
- connection wasnt ever opened.
- """
- conn = self.conn
- conn.close() # type: ignore
- self.conn = None
-
- def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]:
- """
- Returns a dictionary of {filename: {attributes}} for all files
- on the remote system (where the MLSD command is supported).
-
- :param path: full path to the remote directory
- :type path: str
- """
- conn = self.get_conn()
- flist = conn.listdir_attr(path)
- files = {}
- for f in flist:
- modify = datetime.datetime.fromtimestamp(
- f.st_mtime).strftime('%Y%m%d%H%M%S')
- files[f.filename] = {
- 'size': f.st_size,
- 'type': 'dir' if stat.S_ISDIR(f.st_mode) else 'file',
- 'modify': modify}
- return files
-
- def list_directory(self, path: str) -> List[str]:
- """
- Returns a list of files on the remote system.
-
- :param path: full path to the remote directory to list
- :type path: str
- """
- conn = self.get_conn()
- files = conn.listdir(path)
- return files
-
- def create_directory(self, path: str, mode: int = 777) -> None:
- """
- Creates a directory on the remote system.
-
- :param path: full path to the remote directory to create
- :type path: str
- :param mode: int representation of octal mode for directory
- """
- conn = self.get_conn()
- conn.makedirs(path, mode)
-
- def delete_directory(self, path: str) -> None:
- """
- Deletes a directory on the remote system.
-
- :param path: full path to the remote directory to delete
- :type path: str
- """
- conn = self.get_conn()
- conn.rmdir(path)
-
- def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None:
- """
- Transfers the remote file to a local location.
- If local_full_path is a string path, the file will be put
- at that location
-
- :param remote_full_path: full path to the remote file
- :type remote_full_path: str
- :param local_full_path: full path to the local file
- :type local_full_path: str
- """
- conn = self.get_conn()
- self.log.info('Retrieving file from FTP: %s', remote_full_path)
- conn.get(remote_full_path, local_full_path)
- self.log.info('Finished retrieving file from FTP: %s', remote_full_path)
-
- def store_file(self, remote_full_path: str, local_full_path: str) -> None:
- """
- Transfers a local file to the remote location.
- If local_full_path_or_buffer is a string path, the file will be read
- from that location
-
- :param remote_full_path: full path to the remote file
- :type remote_full_path: str
- :param local_full_path: full path to the local file
- :type local_full_path: str
- """
- conn = self.get_conn()
- conn.put(local_full_path, remote_full_path)
-
- def delete_file(self, path: str) -> None:
- """
- Removes a file on the FTP Server
-
- :param path: full path to the remote file
- :type path: str
- """
- conn = self.get_conn()
- conn.remove(path)
-
- def get_mod_time(self, path: str) -> str:
- conn = self.get_conn()
- ftp_mdtm = conn.stat(path).st_mtime
- return datetime.datetime.fromtimestamp(ftp_mdtm).strftime('%Y%m%d%H%M%S')
-
- def path_exists(self, path: str) -> bool:
- """
- Returns True if a remote entity exists
-
- :param path: full path to the remote file or directory
- :type path: str
- """
- conn = self.get_conn()
- return conn.exists(path)
-
- @staticmethod
- def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None) -> bool:
- """
- Return True if given path starts with prefix (if set) and ends with delimiter (if set).
-
- :param path: path to be checked
- :type path: str
- :param prefix: if set path will be checked is starting with prefix
- :type prefix: str
- :param delimiter: if set path will be checked is ending with suffix
- :type delimiter: str
- :return: bool
- """
- if prefix is not None and not path.startswith(prefix):
- return False
- if delimiter is not None and not path.endswith(delimiter):
- return False
- return True
-
- def get_tree_map(
- self, path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None
- ) -> Tuple[List[str], List[str], List[str]]:
- """
- Return tuple with recursive lists of files, directories and unknown paths from given path.
- It is possible to filter results by giving prefix and/or delimiter parameters.
-
- :param path: path from which tree will be built
- :type path: str
- :param prefix: if set paths will be added if start with prefix
- :type prefix: str
- :param delimiter: if set paths will be added if end with delimiter
- :type delimiter: str
- :return: tuple with list of files, dirs and unknown items
- :rtype: Tuple[List[str], List[str], List[str]]
- """
- conn = self.get_conn()
- files, dirs, unknowns = [], [], [] # type: List[str], List[str], List[str]
-
- def append_matching_path_callback(list_):
- return (
- lambda item: list_.append(item)
- if self._is_path_match(item, prefix, delimiter)
- else None
- )
-
- conn.walktree(
- remotepath=path,
- fcallback=append_matching_path_callback(files),
- dcallback=append_matching_path_callback(dirs),
- ucallback=append_matching_path_callback(unknowns),
- recurse=True,
- )
-
- return files, dirs, unknowns
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.sftp.hooks.sftp_hook`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py
index d71e825282..a84f54f44d 100644
--- a/airflow/contrib/operators/sftp_operator.py
+++ b/airflow/contrib/operators/sftp_operator.py
@@ -16,165 +16,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import os
+"""
+This module is deprecated. Please use `airflow.providers.sftp.operators.sftp_operator`.
+"""
-from airflow.contrib.hooks.ssh_hook import SSHHook
-from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
-from airflow.utils.decorators import apply_defaults
+import warnings
+# pylint: disable=unused-import
+from airflow.providers.sftp.operators.sftp_operator import SFTPOperator # noqa
-class SFTPOperation:
- PUT = 'put'
- GET = 'get'
-
-
-class SFTPOperator(BaseOperator):
- """
- SFTPOperator for transferring files from remote host to local or vice a versa.
- This operator uses ssh_hook to open sftp transport channel that serve as basis
- for file transfer.
-
- :param ssh_hook: predefined ssh_hook to use for remote execution.
- Either `ssh_hook` or `ssh_conn_id` needs to be provided.
- :type ssh_hook: airflow.contrib.hooks.ssh_hook.SSHHook
- :param ssh_conn_id: connection id from airflow Connections.
- `ssh_conn_id` will be ignored if `ssh_hook` is provided.
- :type ssh_conn_id: str
- :param remote_host: remote host to connect (templated)
- Nullable. If provided, it will replace the `remote_host` which was
- defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
- :type remote_host: str
- :param local_filepath: local file path to get or put. (templated)
- :type local_filepath: str
- :param remote_filepath: remote file path to get or put. (templated)
- :type remote_filepath: str
- :param operation: specify operation 'get' or 'put', defaults to put
- :type operation: str
- :param confirm: specify if the SFTP operation should be confirmed, defaults to True
- :type confirm: bool
- :param create_intermediate_dirs: create missing intermediate directories when
- copying from remote to local and vice-versa. Default is False.
-
- Example: The following task would copy ``file.txt`` to the remote host
- at ``/tmp/tmp1/tmp2/`` while creating ``tmp``,``tmp1`` and ``tmp2`` if they
- don't exist. If the parameter is not passed it would error as the directory
- does not exist. ::
-
- put_file = SFTPOperator(
- task_id="test_sftp",
- ssh_conn_id="ssh_default",
- local_filepath="/tmp/file.txt",
- remote_filepath="/tmp/tmp1/tmp2/file.txt",
- operation="put",
- create_intermediate_dirs=True,
- dag=dag
- )
-
- :type create_intermediate_dirs: bool
- """
- template_fields = ('local_filepath', 'remote_filepath', 'remote_host')
-
- @apply_defaults
- def __init__(self,
- ssh_hook=None,
- ssh_conn_id=None,
- remote_host=None,
- local_filepath=None,
- remote_filepath=None,
- operation=SFTPOperation.PUT,
- confirm=True,
- create_intermediate_dirs=False,
- *args,
- **kwargs):
- super().__init__(*args, **kwargs)
- self.ssh_hook = ssh_hook
- self.ssh_conn_id = ssh_conn_id
- self.remote_host = remote_host
- self.local_filepath = local_filepath
- self.remote_filepath = remote_filepath
- self.operation = operation
- self.confirm = confirm
- self.create_intermediate_dirs = create_intermediate_dirs
- if not (self.operation.lower() == SFTPOperation.GET or
- self.operation.lower() == SFTPOperation.PUT):
- raise TypeError("unsupported operation value {0}, expected {1} or {2}"
- .format(self.operation, SFTPOperation.GET, SFTPOperation.PUT))
-
- def execute(self, context):
- file_msg = None
- try:
- if self.ssh_conn_id:
- if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
- self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
- else:
- self.log.info("ssh_hook is not provided or invalid. " +
- "Trying ssh_conn_id to create SSHHook.")
- self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)
-
- if not self.ssh_hook:
- raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")
-
- if self.remote_host is not None:
- self.log.info("remote_host is provided explicitly. " +
- "It will replace the remote_host which was defined " +
- "in ssh_hook or predefined in connection of ssh_conn_id.")
- self.ssh_hook.remote_host = self.remote_host
-
- with self.ssh_hook.get_conn() as ssh_client:
- sftp_client = ssh_client.open_sftp()
- if self.operation.lower() == SFTPOperation.GET:
- local_folder = os.path.dirname(self.local_filepath)
- if self.create_intermediate_dirs:
- # Create Intermediate Directories if it doesn't exist
- try:
- os.makedirs(local_folder)
- except OSError:
- if not os.path.isdir(local_folder):
- raise
- file_msg = "from {0} to {1}".format(self.remote_filepath,
- self.local_filepath)
- self.log.info("Starting to transfer %s", file_msg)
- sftp_client.get(self.remote_filepath, self.local_filepath)
- else:
- remote_folder = os.path.dirname(self.remote_filepath)
- if self.create_intermediate_dirs:
- _make_intermediate_dirs(
- sftp_client=sftp_client,
- remote_directory=remote_folder,
- )
- file_msg = "from {0} to {1}".format(self.local_filepath,
- self.remote_filepath)
- self.log.info("Starting to transfer file %s", file_msg)
- sftp_client.put(self.local_filepath,
- self.remote_filepath,
- confirm=self.confirm)
-
- except Exception as e:
- raise AirflowException("Error while transferring {0}, error: {1}"
- .format(file_msg, str(e)))
-
- return self.local_filepath
-
-
-def _make_intermediate_dirs(sftp_client, remote_directory):
- """
- Create all the intermediate directories in a remote host
-
- :param sftp_client: A Paramiko SFTP client.
- :param remote_directory: Absolute Path of the directory containing the file
- :return:
- """
- if remote_directory == '/':
- sftp_client.chdir('/')
- return
- if remote_directory == '':
- return
- try:
- sftp_client.chdir(remote_directory)
- except OSError:
- dirname, basename = os.path.split(remote_directory.rstrip('/'))
- _make_intermediate_dirs(sftp_client, dirname)
- sftp_client.mkdir(basename)
- sftp_client.chdir(basename)
- return
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.sftp.operators.sftp_operator`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/contrib/sensors/sftp_sensor.py b/airflow/contrib/sensors/sftp_sensor.py
index f58bc27bd9..241763f9e1 100644
--- a/airflow/contrib/sensors/sftp_sensor.py
+++ b/airflow/contrib/sensors/sftp_sensor.py
@@ -16,40 +16,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""
+This module is deprecated. Please use `airflow.providers.sftp.sensors.sftp_sensor`.
+"""
-from paramiko import SFTP_NO_SUCH_FILE
+import warnings
-from airflow.contrib.hooks.sftp_hook import SFTPHook
-from airflow.sensors.base_sensor_operator import BaseSensorOperator
-from airflow.utils.decorators import apply_defaults
+# pylint: disable=unused-import
+from airflow.providers.sftp.sensors.sftp_sensor import SFTPSensor # noqa
-
-class SFTPSensor(BaseSensorOperator):
- """
- Waits for a file or directory to be present on SFTP.
-
- :param path: Remote file or directory path
- :type path: str
- :param sftp_conn_id: The connection to run the sensor against
- :type sftp_conn_id: str
- """
- template_fields = ('path',)
-
- @apply_defaults
- def __init__(self, path, sftp_conn_id='sftp_default', *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.path = path
- self.hook = None
- self.sftp_conn_id = sftp_conn_id
-
- def poke(self, context):
- self.hook = SFTPHook(self.sftp_conn_id)
- self.log.info('Poking for %s', self.path)
- try:
- self.hook.get_mod_time(self.path)
- except OSError as e:
- if e.errno != SFTP_NO_SUCH_FILE:
- raise e
- return False
- self.hook.close_conn()
- return True
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.sftp.sensors.sftp_sensor`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/operators/gcs_to_sftp.py b/airflow/operators/gcs_to_sftp.py
index 3fbcaaf0bb..dd302aafa8 100644
--- a/airflow/operators/gcs_to_sftp.py
+++ b/airflow/operators/gcs_to_sftp.py
@@ -24,9 +24,9 @@ from tempfile import NamedTemporaryFile
from typing import Optional
from airflow import AirflowException
-from airflow.contrib.hooks.sftp_hook import SFTPHook
from airflow.gcp.hooks.gcs import GoogleCloudStorageHook
from airflow.models import BaseOperator
+from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
from airflow.utils.decorators import apply_defaults
WILDCARD = "*"
diff --git a/airflow/providers/google/cloud/operators/sftp_to_gcs.py b/airflow/providers/google/cloud/operators/sftp_to_gcs.py
index 478bff1e86..7d6a6b68e6 100644
--- a/airflow/providers/google/cloud/operators/sftp_to_gcs.py
+++ b/airflow/providers/google/cloud/operators/sftp_to_gcs.py
@@ -23,9 +23,9 @@ from tempfile import NamedTemporaryFile
from typing import Optional, Union
from airflow import AirflowException
-from airflow.contrib.hooks.sftp_hook import SFTPHook
from airflow.gcp.hooks.gcs import GoogleCloudStorageHook
from airflow.models import BaseOperator
+from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
from airflow.utils.decorators import apply_defaults
WILDCARD = "*"
diff --git a/airflow/providers/sftp/__init__.py b/airflow/providers/sftp/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/sftp/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/sftp/hooks/__init__.py b/airflow/providers/sftp/hooks/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/sftp/hooks/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/sftp/hooks/sftp_hook.py b/airflow/providers/sftp/hooks/sftp_hook.py
new file mode 100644
index 0000000000..8526798536
--- /dev/null
+++ b/airflow/providers/sftp/hooks/sftp_hook.py
@@ -0,0 +1,297 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This module contains SFTP hook.
+"""
+import datetime
+import stat
+from typing import Dict, List, Optional, Tuple
+
+import pysftp
+
+from airflow.contrib.hooks.ssh_hook import SSHHook
+
+
+class SFTPHook(SSHHook):
+ """
+ This hook is inherited from SSH hook. Please refer to SSH hook for the input
+ arguments.
+
+ Interact with SFTP. Aims to be interchangeable with FTPHook.
+
+ :Pitfalls::
+
+ - In contrast with FTPHook describe_directory only returns size, type and
+ modify. It doesn't return unix.owner, unix.mode, perm, unix.group and
+ unique.
+ - retrieve_file and store_file only take a local full path and not a
+ buffer.
+ - If no mode is passed to create_directory it will be created with 777
+ permissions.
+
+ Errors that may occur throughout but should be handled downstream.
+ """
+
+ def __init__(self, ftp_conn_id: str = 'sftp_default', *args, **kwargs) -> None:
+ kwargs['ssh_conn_id'] = ftp_conn_id
+ super().__init__(*args, **kwargs)
+
+ self.conn = None
+ self.private_key_pass = None
+
+ # Fail for unverified hosts, unless this is explicitly allowed
+ self.no_host_key_check = False
+
+ if self.ssh_conn_id is not None:
+ conn = self.get_connection(self.ssh_conn_id)
+ if conn.extra is not None:
+ extra_options = conn.extra_dejson
+ if 'private_key_pass' in extra_options:
+ self.private_key_pass = extra_options.get('private_key_pass', None)
+
+ # For backward compatibility
+ # TODO: remove in Airflow 2.1
+ import warnings
+ if 'ignore_hostkey_verification' in extra_options:
+ warnings.warn(
+ 'Extra option `ignore_hostkey_verification` is deprecated.'
+ 'Please use `no_host_key_check` instead.'
+ 'This option will be removed in Airflow 2.1',
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ self.no_host_key_check = str(
+ extra_options['ignore_hostkey_verification']
+ ).lower() == 'true'
+
+ if 'no_host_key_check' in extra_options:
+ self.no_host_key_check = str(
+ extra_options['no_host_key_check']).lower() == 'true'
+
+ if 'private_key' in extra_options:
+ warnings.warn(
+ 'Extra option `private_key` is deprecated.'
+ 'Please use `key_file` instead.'
+ 'This option will be removed in Airflow 2.1',
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ self.key_file = extra_options.get('private_key')
+
+ def get_conn(self) -> pysftp.Connection:
+ """
+ Returns an SFTP connection object
+ """
+ if self.conn is None:
+ cnopts = pysftp.CnOpts()
+ if self.no_host_key_check:
+ cnopts.hostkeys = None
+ cnopts.compression = self.compress
+ conn_params = {
+ 'host': self.remote_host,
+ 'port': self.port,
+ 'username': self.username,
+ 'cnopts': cnopts
+ }
+ if self.password and self.password.strip():
+ conn_params['password'] = self.password
+ if self.key_file:
+ conn_params['private_key'] = self.key_file
+ if self.private_key_pass:
+ conn_params['private_key_pass'] = self.private_key_pass
+
+ self.conn = pysftp.Connection(**conn_params)
+ return self.conn
+
+ def close_conn(self) -> None:
+ """
+ Closes the connection. An error will occur if the
+        connection wasn't ever opened.
+ """
+ conn = self.conn
+ conn.close() # type: ignore
+ self.conn = None
+
+ def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]:
+ """
+ Returns a dictionary of {filename: {attributes}} for all files
+ on the remote system (where the MLSD command is supported).
+
+ :param path: full path to the remote directory
+ :type path: str
+ """
+ conn = self.get_conn()
+ flist = conn.listdir_attr(path)
+ files = {}
+ for f in flist:
+ modify = datetime.datetime.fromtimestamp(
+ f.st_mtime).strftime('%Y%m%d%H%M%S')
+ files[f.filename] = {
+ 'size': f.st_size,
+ 'type': 'dir' if stat.S_ISDIR(f.st_mode) else 'file',
+ 'modify': modify}
+ return files
+
+ def list_directory(self, path: str) -> List[str]:
+ """
+ Returns a list of files on the remote system.
+
+ :param path: full path to the remote directory to list
+ :type path: str
+ """
+ conn = self.get_conn()
+ files = conn.listdir(path)
+ return files
+
+ def create_directory(self, path: str, mode: int = 777) -> None:
+ """
+ Creates a directory on the remote system.
+
+ :param path: full path to the remote directory to create
+ :type path: str
+ :param mode: int representation of octal mode for directory
+ """
+ conn = self.get_conn()
+ conn.makedirs(path, mode)
+
+ def delete_directory(self, path: str) -> None:
+ """
+ Deletes a directory on the remote system.
+
+ :param path: full path to the remote directory to delete
+ :type path: str
+ """
+ conn = self.get_conn()
+ conn.rmdir(path)
+
+ def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None:
+ """
+ Transfers the remote file to a local location.
+ If local_full_path is a string path, the file will be put
+ at that location
+
+ :param remote_full_path: full path to the remote file
+ :type remote_full_path: str
+ :param local_full_path: full path to the local file
+ :type local_full_path: str
+ """
+ conn = self.get_conn()
+ self.log.info('Retrieving file from FTP: %s', remote_full_path)
+ conn.get(remote_full_path, local_full_path)
+ self.log.info('Finished retrieving file from FTP: %s', remote_full_path)
+
+ def store_file(self, remote_full_path: str, local_full_path: str) -> None:
+ """
+ Transfers a local file to the remote location.
+        If local_full_path is a string path, the file will be read
+ from that location
+
+ :param remote_full_path: full path to the remote file
+ :type remote_full_path: str
+ :param local_full_path: full path to the local file
+ :type local_full_path: str
+ """
+ conn = self.get_conn()
+ conn.put(local_full_path, remote_full_path)
+
+ def delete_file(self, path: str) -> None:
+ """
+ Removes a file on the FTP Server
+
+ :param path: full path to the remote file
+ :type path: str
+ """
+ conn = self.get_conn()
+ conn.remove(path)
+
+ def get_mod_time(self, path: str) -> str:
+ """
+ Returns modification time.
+
+ :param path: full path to the remote file
+ :type path: str
+ """
+ conn = self.get_conn()
+ ftp_mdtm = conn.stat(path).st_mtime
+ return datetime.datetime.fromtimestamp(ftp_mdtm).strftime('%Y%m%d%H%M%S')
+
+ def path_exists(self, path: str) -> bool:
+ """
+ Returns True if a remote entity exists
+
+ :param path: full path to the remote file or directory
+ :type path: str
+ """
+ conn = self.get_conn()
+ return conn.exists(path)
+
+ @staticmethod
+ def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None) -> bool:
+ """
+ Return True if given path starts with prefix (if set) and ends with delimiter (if set).
+
+ :param path: path to be checked
+ :type path: str
+        :param prefix: if set, the path will be checked if it starts with prefix
+ :type prefix: str
+        :param delimiter: if set, the path will be checked if it ends with delimiter
+ :type delimiter: str
+ :return: bool
+ """
+ if prefix is not None and not path.startswith(prefix):
+ return False
+ if delimiter is not None and not path.endswith(delimiter):
+ return False
+ return True
+
+ def get_tree_map(
+ self, path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None
+ ) -> Tuple[List[str], List[str], List[str]]:
+ """
+ Return tuple with recursive lists of files, directories and unknown paths from given path.
+ It is possible to filter results by giving prefix and/or delimiter parameters.
+
+ :param path: path from which tree will be built
+ :type path: str
+ :param prefix: if set paths will be added if start with prefix
+ :type prefix: str
+ :param delimiter: if set paths will be added if end with delimiter
+ :type delimiter: str
+ :return: tuple with list of files, dirs and unknown items
+ :rtype: Tuple[List[str], List[str], List[str]]
+ """
+ conn = self.get_conn()
+ files, dirs, unknowns = [], [], [] # type: List[str], List[str], List[str]
+
+ def append_matching_path_callback(list_):
+ return (
+ lambda item: list_.append(item)
+ if self._is_path_match(item, prefix, delimiter)
+ else None
+ )
+
+ conn.walktree(
+ remotepath=path,
+ fcallback=append_matching_path_callback(files),
+ dcallback=append_matching_path_callback(dirs),
+ ucallback=append_matching_path_callback(unknowns),
+ recurse=True,
+ )
+
+ return files, dirs, unknowns
diff --git a/airflow/providers/sftp/operators/__init__.py b/airflow/providers/sftp/operators/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/sftp/operators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/sftp/operators/sftp_operator.py b/airflow/providers/sftp/operators/sftp_operator.py
new file mode 100644
index 0000000000..a2753a91a9
--- /dev/null
+++ b/airflow/providers/sftp/operators/sftp_operator.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This module contains SFTP operator.
+"""
+import os
+
+from airflow.contrib.hooks.ssh_hook import SSHHook
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+
+# pylint: disable=missing-docstring
+class SFTPOperation:
+    """Namespace of the two transfer directions accepted by :class:`SFTPOperator`."""
+    PUT = 'put'
+    GET = 'get'
+
+
+class SFTPOperator(BaseOperator):
+    """
+    SFTPOperator for transferring files from remote host to local or vice versa.
+    This operator uses ssh_hook to open an sftp transport channel that serves as
+    the basis for file transfer.
+
+    :param ssh_hook: predefined ssh_hook to use for remote execution.
+        Either `ssh_hook` or `ssh_conn_id` needs to be provided.
+    :type ssh_hook: airflow.contrib.hooks.ssh_hook.SSHHook
+    :param ssh_conn_id: connection id from airflow Connections.
+        `ssh_conn_id` will be ignored if `ssh_hook` is provided.
+    :type ssh_conn_id: str
+    :param remote_host: remote host to connect (templated)
+        Nullable. If provided, it will replace the `remote_host` which was
+        defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
+    :type remote_host: str
+    :param local_filepath: local file path to get or put. (templated)
+    :type local_filepath: str
+    :param remote_filepath: remote file path to get or put. (templated)
+    :type remote_filepath: str
+    :param operation: specify operation 'get' or 'put', defaults to put
+    :type operation: str
+    :param confirm: specify if the SFTP operation should be confirmed, defaults to True
+    :type confirm: bool
+    :param create_intermediate_dirs: create missing intermediate directories when
+        copying from remote to local and vice-versa. Default is False.
+
+        Example: The following task would copy ``file.txt`` to the remote host
+        at ``/tmp/tmp1/tmp2/`` while creating ``tmp``, ``tmp1`` and ``tmp2`` if they
+        don't exist. If the parameter is not passed it would error as the directory
+        does not exist. ::
+
+            put_file = SFTPOperator(
+                task_id="test_sftp",
+                ssh_conn_id="ssh_default",
+                local_filepath="/tmp/file.txt",
+                remote_filepath="/tmp/tmp1/tmp2/file.txt",
+                operation="put",
+                create_intermediate_dirs=True,
+                dag=dag
+            )
+
+    :type create_intermediate_dirs: bool
+    """
+    template_fields = ('local_filepath', 'remote_filepath', 'remote_host')
+
+    @apply_defaults
+    def __init__(self,
+                 ssh_hook=None,
+                 ssh_conn_id=None,
+                 remote_host=None,
+                 local_filepath=None,
+                 remote_filepath=None,
+                 operation=SFTPOperation.PUT,
+                 confirm=True,
+                 create_intermediate_dirs=False,
+                 *args,
+                 **kwargs):
+        super().__init__(*args, **kwargs)
+        self.ssh_hook = ssh_hook
+        self.ssh_conn_id = ssh_conn_id
+        self.remote_host = remote_host
+        self.local_filepath = local_filepath
+        self.remote_filepath = remote_filepath
+        self.operation = operation
+        self.confirm = confirm
+        self.create_intermediate_dirs = create_intermediate_dirs
+        # Validate the operation eagerly so a bad DAG definition fails at parse
+        # time rather than at execution time.
+        if not (self.operation.lower() == SFTPOperation.GET or
+                self.operation.lower() == SFTPOperation.PUT):
+            raise TypeError("unsupported operation value {0}, expected {1} or {2}"
+                            .format(self.operation, SFTPOperation.GET, SFTPOperation.PUT))
+
+    def execute(self, context):
+        file_msg = None
+        try:
+            # An explicit ssh_hook takes precedence over ssh_conn_id.
+            if self.ssh_conn_id:
+                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
+                    self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
+                else:
+                    self.log.info("ssh_hook is not provided or invalid. "
+                                  "Trying ssh_conn_id to create SSHHook.")
+                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)
+
+            if not self.ssh_hook:
+                raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")
+
+            if self.remote_host is not None:
+                self.log.info("remote_host is provided explicitly. "
+                              "It will replace the remote_host which was defined "
+                              "in ssh_hook or predefined in connection of ssh_conn_id.")
+                self.ssh_hook.remote_host = self.remote_host
+
+            with self.ssh_hook.get_conn() as ssh_client:
+                sftp_client = ssh_client.open_sftp()
+                if self.operation.lower() == SFTPOperation.GET:
+                    local_folder = os.path.dirname(self.local_filepath)
+                    if self.create_intermediate_dirs:
+                        # Create Intermediate Directories if it doesn't exist
+                        try:
+                            os.makedirs(local_folder)
+                        except OSError:
+                            # Ignore "already exists"; re-raise anything else.
+                            if not os.path.isdir(local_folder):
+                                raise
+                    file_msg = "from {0} to {1}".format(self.remote_filepath,
+                                                       self.local_filepath)
+                    self.log.info("Starting to transfer %s", file_msg)
+                    sftp_client.get(self.remote_filepath, self.local_filepath)
+                else:
+                    remote_folder = os.path.dirname(self.remote_filepath)
+                    if self.create_intermediate_dirs:
+                        _make_intermediate_dirs(
+                            sftp_client=sftp_client,
+                            remote_directory=remote_folder,
+                        )
+                    file_msg = "from {0} to {1}".format(self.local_filepath,
+                                                       self.remote_filepath)
+                    self.log.info("Starting to transfer file %s", file_msg)
+                    sftp_client.put(self.local_filepath,
+                                    self.remote_filepath,
+                                    confirm=self.confirm)
+
+        except Exception as e:
+            # Wrap any failure (connection, transfer, missing path) into an
+            # AirflowException that records which transfer was attempted.
+            raise AirflowException("Error while transferring {0}, error: {1}"
+                                   .format(file_msg, str(e)))
+
+        # Returned value is pushed to XCom for downstream tasks.
+        return self.local_filepath
+
+
+def _make_intermediate_dirs(sftp_client, remote_directory):
+    """
+    Create all the intermediate directories in a remote host.
+
+    Recurses towards the filesystem root, creating each missing parent before
+    its child. As a side effect the client's working directory is left at
+    ``remote_directory``.
+
+    :param sftp_client: A Paramiko SFTP client.
+    :param remote_directory: Absolute Path of the directory containing the file
+    :return: None
+    """
+    if remote_directory == '/':
+        # Root always exists; just make it the working directory.
+        sftp_client.chdir('/')
+        return
+    if remote_directory == '':
+        # Nothing left to create (recursion bottomed out on a relative path).
+        return
+    try:
+        sftp_client.chdir(remote_directory)
+    except OSError:
+        # chdir failed — presumably the directory does not exist yet
+        # (paramiko raises IOError/OSError here); create the parent chain
+        # first, then this directory.
+        dirname, basename = os.path.split(remote_directory.rstrip('/'))
+        _make_intermediate_dirs(sftp_client, dirname)
+        sftp_client.mkdir(basename)
+        sftp_client.chdir(basename)
+        return
diff --git a/airflow/providers/sftp/sensors/__init__.py b/airflow/providers/sftp/sensors/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/sftp/sensors/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/sftp/sensors/sftp_sensor.py b/airflow/providers/sftp/sensors/sftp_sensor.py
new file mode 100644
index 0000000000..a3bb0d6ee9
--- /dev/null
+++ b/airflow/providers/sftp/sensors/sftp_sensor.py
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This module contains SFTP sensor.
+"""
+from paramiko import SFTP_NO_SUCH_FILE
+
+from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
+from airflow.sensors.base_sensor_operator import BaseSensorOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class SFTPSensor(BaseSensorOperator):
+    """
+    Waits for a file or directory to be present on SFTP.
+
+    :param path: Remote file or directory path
+    :type path: str
+    :param sftp_conn_id: The connection to run the sensor against
+    :type sftp_conn_id: str
+    """
+    template_fields = ('path',)
+
+    @apply_defaults
+    def __init__(self, path, sftp_conn_id='sftp_default', *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.path = path
+        # Hook is created lazily in poke(), once per poke interval.
+        self.hook = None
+        self.sftp_conn_id = sftp_conn_id
+
+    def poke(self, context):
+        self.hook = SFTPHook(self.sftp_conn_id)
+        self.log.info('Poking for %s', self.path)
+        try:
+            # Existence is probed by asking for the modification time.
+            self.hook.get_mod_time(self.path)
+        except OSError as e:
+            # Only swallow "no such file"; any other error is a real failure.
+            if e.errno != SFTP_NO_SUCH_FILE:
+                raise e
+            # NOTE(review): the hook connection is not closed on this path —
+            # close_conn() below only runs when the file exists. Confirm this
+            # is intended.
+            return False
+        self.hook.close_conn()
+        return True
diff --git a/docs/autoapi_templates/index.rst b/docs/autoapi_templates/index.rst
index 114ee88071..8f39072098 100644
--- a/docs/autoapi_templates/index.rst
+++ b/docs/autoapi_templates/index.rst
@@ -76,6 +76,8 @@ All operators are in the following packages:
airflow/providers/amazon/aws/sensors/index
+ airflow/providers/apache/cassandra/sensors/index
+
airflow/providers/google/cloud/operators/index
airflow/providers/google/cloud/sensors/index
@@ -84,7 +86,9 @@ All operators are in the following packages:
airflow/providers/google/marketing_platform/sensors/index
- airflow/providers/apache/cassandra/sensors/index
+ airflow/providers/sftp/operators/index
+
+ airflow/providers/sftp/sensors/index
Hooks
-----
@@ -117,6 +121,8 @@ All hooks are in the following packages:
airflow/providers/apache/cassandra/hooks/index
+ airflow/providers/sftp/hooks/index
+
Executors
---------
Executors are the mechanism by which task instances get run. All executors are
diff --git a/docs/conf.py b/docs/conf.py
index 96718a2902..90e3857573 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -233,6 +233,7 @@ exclude_patterns = [
'_api/airflow/providers/amazon/aws/example_dags',
'_api/airflow/providers/apache/index.rst',
'_api/airflow/providers/apache/cassandra/index.rst',
+ '_api/airflow/providers/sftp/index.rst',
'_api/enums/index.rst',
'_api/json_schema/index.rst',
'_api/base_serialization/index.rst',
diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst
index 48e8ed723f..50e82c232d 100644
--- a/docs/operators-and-hooks-ref.rst
+++ b/docs/operators-and-hooks-ref.rst
@@ -1246,9 +1246,9 @@ communication protocols or interface.
* - `SSH File Transfer Protocol (SFTP) `__
-
- - :mod:`airflow.contrib.hooks.sftp_hook`
- - :mod:`airflow.contrib.operators.sftp_operator`
- - :mod:`airflow.contrib.sensors.sftp_sensor`
+ - :mod:`airflow.providers.sftp.hooks.sftp_hook`
+ - :mod:`airflow.providers.sftp.operators.sftp_operator`
+ - :mod:`airflow.providers.sftp.sensors.sftp_sensor`
* - `Secure Shell (SSH) `__
-
diff --git a/scripts/ci/pylint_todo.txt b/scripts/ci/pylint_todo.txt
index 2de583bce7..116f44ab15 100644
--- a/scripts/ci/pylint_todo.txt
+++ b/scripts/ci/pylint_todo.txt
@@ -21,7 +21,6 @@
./airflow/contrib/hooks/qubole_check_hook.py
./airflow/contrib/hooks/sagemaker_hook.py
./airflow/contrib/hooks/segment_hook.py
-./airflow/contrib/hooks/sftp_hook.py
./airflow/contrib/hooks/slack_webhook_hook.py
./airflow/contrib/hooks/snowflake_hook.py
./airflow/contrib/hooks/spark_jdbc_hook.py
@@ -65,7 +64,6 @@
./airflow/contrib/operators/sagemaker_transform_operator.py
./airflow/contrib/operators/sagemaker_tuning_operator.py
./airflow/contrib/operators/segment_track_event_operator.py
-./airflow/contrib/operators/sftp_operator.py
./airflow/contrib/operators/sftp_to_s3_operator.py
./airflow/contrib/operators/slack_webhook_operator.py
./airflow/contrib/operators/snowflake_operator.py
@@ -100,7 +98,6 @@
./airflow/contrib/sensors/sagemaker_training_sensor.py
./airflow/contrib/sensors/sagemaker_transform_sensor.py
./airflow/contrib/sensors/sagemaker_tuning_sensor.py
-./airflow/contrib/sensors/sftp_sensor.py
./airflow/contrib/sensors/wasb_sensor.py
./airflow/contrib/sensors/weekday_sensor.py
./airflow/hooks/dbapi_hook.py
diff --git a/tests/contrib/hooks/test_sftp_hook.py b/tests/contrib/hooks/test_sftp_hook.py
index 0204038b79..ed1f250ca8 100644
--- a/tests/contrib/hooks/test_sftp_hook.py
+++ b/tests/contrib/hooks/test_sftp_hook.py
@@ -17,231 +17,17 @@
# specific language governing permissions and limitations
# under the License.
-import os
-import shutil
-import unittest
-from unittest import mock
+import importlib
+from unittest import TestCase
-import pysftp
-from parameterized import parameterized
-
-from airflow.contrib.hooks.sftp_hook import SFTPHook
-from airflow.models import Connection
-from airflow.utils.db import provide_session
-
-TMP_PATH = '/tmp'
-TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
-SUB_DIR = "sub_dir"
-TMP_FILE_FOR_TESTS = 'test_file.txt'
-
-SFTP_CONNECTION_USER = "root"
+OLD_PATH = "airflow.contrib.hooks.sftp_hook"
+NEW_PATH = "airflow.providers.sftp.hooks.sftp_hook"
+WARNING_MESSAGE = "This module is deprecated. Please use `{}`.".format(NEW_PATH)
-class TestSFTPHook(unittest.TestCase):
-
- @provide_session
- def update_connection(self, login, session=None):
- connection = (session.query(Connection).
- filter(Connection.conn_id == "sftp_default")
- .first())
- old_login = connection.login
- connection.login = login
- session.commit()
- return old_login
-
- def setUp(self):
- self.old_login = self.update_connection(SFTP_CONNECTION_USER)
- self.hook = SFTPHook()
- os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))
-
- with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file:
- file.write('Test file')
- with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file:
- file.write('Test file')
-
- def test_get_conn(self):
- output = self.hook.get_conn()
- self.assertEqual(type(output), pysftp.Connection)
-
- def test_close_conn(self):
- self.hook.conn = self.hook.get_conn()
- self.assertTrue(self.hook.conn is not None)
- self.hook.close_conn()
- self.assertTrue(self.hook.conn is None)
-
- def test_describe_directory(self):
- output = self.hook.describe_directory(TMP_PATH)
- self.assertTrue(TMP_DIR_FOR_TESTS in output)
-
- def test_list_directory(self):
- output = self.hook.list_directory(
- path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- self.assertEqual(output, [SUB_DIR])
-
- def test_create_and_delete_directory(self):
- new_dir_name = 'new_dir'
- self.hook.create_directory(os.path.join(
- TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
- output = self.hook.describe_directory(
- os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- self.assertTrue(new_dir_name in output)
- self.hook.delete_directory(os.path.join(
- TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
- output = self.hook.describe_directory(
- os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- self.assertTrue(new_dir_name not in output)
-
- def test_create_and_delete_directories(self):
- base_dir = "base_dir"
- sub_dir = "sub_dir"
- new_dir_path = os.path.join(base_dir, sub_dir)
- self.hook.create_directory(os.path.join(
- TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
- output = self.hook.describe_directory(
- os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- self.assertTrue(base_dir in output)
- output = self.hook.describe_directory(
- os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
- self.assertTrue(sub_dir in output)
- self.hook.delete_directory(os.path.join(
- TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
- self.hook.delete_directory(os.path.join(
- TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
- output = self.hook.describe_directory(
- os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- self.assertTrue(new_dir_path not in output)
- self.assertTrue(base_dir not in output)
-
- def test_store_retrieve_and_delete_file(self):
- self.hook.store_file(
- remote_full_path=os.path.join(
- TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
- local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
- )
- output = self.hook.list_directory(
- path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- self.assertEqual(output, [SUB_DIR, TMP_FILE_FOR_TESTS])
- retrieved_file_name = 'retrieved.txt'
- self.hook.retrieve_file(
- remote_full_path=os.path.join(
- TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
- local_full_path=os.path.join(TMP_PATH, retrieved_file_name)
- )
- self.assertTrue(retrieved_file_name in os.listdir(TMP_PATH))
- os.remove(os.path.join(TMP_PATH, retrieved_file_name))
- self.hook.delete_file(path=os.path.join(
- TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
- output = self.hook.list_directory(
- path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- self.assertEqual(output, [SUB_DIR])
-
- def test_get_mod_time(self):
- self.hook.store_file(
- remote_full_path=os.path.join(
- TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
- local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
- )
- output = self.hook.get_mod_time(path=os.path.join(
- TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
- self.assertEqual(len(output), 14)
-
- @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
- def test_no_host_key_check_default(self, get_connection):
- connection = Connection(login='login', host='host')
- get_connection.return_value = connection
- hook = SFTPHook()
- self.assertEqual(hook.no_host_key_check, False)
-
- @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
- def test_no_host_key_check_enabled(self, get_connection):
- connection = Connection(
- login='login', host='host',
- extra='{"no_host_key_check": true}')
-
- get_connection.return_value = connection
- hook = SFTPHook()
- self.assertEqual(hook.no_host_key_check, True)
-
- @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
- def test_no_host_key_check_disabled(self, get_connection):
- connection = Connection(
- login='login', host='host',
- extra='{"no_host_key_check": false}')
-
- get_connection.return_value = connection
- hook = SFTPHook()
- self.assertEqual(hook.no_host_key_check, False)
-
- @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
- def test_no_host_key_check_disabled_for_all_but_true(self, get_connection):
- connection = Connection(
- login='login', host='host',
- extra='{"no_host_key_check": "foo"}')
-
- get_connection.return_value = connection
- hook = SFTPHook()
- self.assertEqual(hook.no_host_key_check, False)
-
- @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
- def test_no_host_key_check_ignore(self, get_connection):
- connection = Connection(
- login='login', host='host',
- extra='{"ignore_hostkey_verification": true}')
-
- get_connection.return_value = connection
- hook = SFTPHook()
- self.assertEqual(hook.no_host_key_check, True)
-
- @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
- def test_no_host_key_check_no_ignore(self, get_connection):
- connection = Connection(
- login='login', host='host',
- extra='{"ignore_hostkey_verification": false}')
-
- get_connection.return_value = connection
- hook = SFTPHook()
- self.assertEqual(hook.no_host_key_check, False)
-
- @parameterized.expand([
- (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True),
- (os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True),
- (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False),
- (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False),
- ])
- def test_path_exists(self, path, exists):
- result = self.hook.path_exists(path)
- self.assertEqual(result, exists)
-
- @parameterized.expand([
- ("test/path/file.bin", None, None, True),
- ("test/path/file.bin", "test", None, True),
- ("test/path/file.bin", "test/", None, True),
- ("test/path/file.bin", None, "bin", True),
- ("test/path/file.bin", "test", "bin", True),
- ("test/path/file.bin", "test/", "file.bin", True),
- ("test/path/file.bin", None, "file.bin", True),
- ("test/path/file.bin", "diff", None, False),
- ("test/path/file.bin", "test//", None, False),
- ("test/path/file.bin", None, ".txt", False),
- ("test/path/file.bin", "diff", ".txt", False),
- ])
- def test_path_match(self, path, prefix, delimiter, match):
- result = self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter)
- self.assertEqual(result, match)
-
- def test_get_tree_map(self):
- tree_map = self.hook.get_tree_map(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- files, dirs, unknowns = tree_map
-
- self.assertEqual(files, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS)])
- self.assertEqual(dirs, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)])
- self.assertEqual(unknowns, [])
-
- def tearDown(self):
- shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
- os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
- self.update_connection(self.old_login)
-
-
-if __name__ == '__main__':
- unittest.main()
+class TestMovingSFTPHookToCore(TestCase):
+    """Check that importing the deprecated contrib module warns with the new path."""
+    def test_move_sftp_hook_to_core(self):
+        with self.assertWarns(DeprecationWarning) as warn:
+            # import_module triggers the module-level DeprecationWarning on
+            # first import (NOTE(review): it does not re-import an
+            # already-loaded module — confirm a reload is not needed here).
+            importlib.import_module(OLD_PATH)
+        self.assertEqual(WARNING_MESSAGE, str(warn.warning))
diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py
index 4b554eab51..cd2c889760 100644
--- a/tests/contrib/operators/test_sftp_operator.py
+++ b/tests/contrib/operators/test_sftp_operator.py
@@ -17,436 +17,17 @@
# specific language governing permissions and limitations
# under the License.
-import os
-import unittest
-from base64 import b64encode
-from unittest import mock
+import importlib
+from unittest import TestCase
-from airflow import AirflowException
-from airflow.contrib.operators.sftp_operator import SFTPOperation, SFTPOperator
-from airflow.contrib.operators.ssh_operator import SSHOperator
-from airflow.models import DAG, TaskInstance
-from airflow.utils import timezone
-from airflow.utils.timezone import datetime
-from tests.test_utils.config import conf_vars
-
-TEST_DAG_ID = 'unit_tests_sftp_op'
-DEFAULT_DATE = datetime(2017, 1, 1)
-TEST_CONN_ID = "conn_id_for_testing"
+OLD_PATH = "airflow.contrib.operators.sftp_operator"
+NEW_PATH = "airflow.providers.sftp.operators.sftp_operator"
+WARNING_MESSAGE = "This module is deprecated. Please use `{}`.".format(NEW_PATH)
-class TestSFTPOperator(unittest.TestCase):
- def setUp(self):
- from airflow.contrib.hooks.ssh_hook import SSHHook
-
- hook = SSHHook(ssh_conn_id='ssh_default')
- hook.no_host_key_check = True
- args = {
- 'owner': 'airflow',
- 'start_date': DEFAULT_DATE,
- }
- dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
- dag.schedule_interval = '@once'
- self.hook = hook
- self.dag = dag
- self.test_dir = "/tmp"
- self.test_local_dir = "/tmp/tmp2"
- self.test_remote_dir = "/tmp/tmp1"
- self.test_local_filename = 'test_local_file'
- self.test_remote_filename = 'test_remote_file'
- self.test_local_filepath = '{0}/{1}'.format(self.test_dir,
- self.test_local_filename)
- # Local Filepath with Intermediate Directory
- self.test_local_filepath_int_dir = '{0}/{1}'.format(self.test_local_dir,
- self.test_local_filename)
- self.test_remote_filepath = '{0}/{1}'.format(self.test_dir,
- self.test_remote_filename)
- # Remote Filepath with Intermediate Directory
- self.test_remote_filepath_int_dir = '{0}/{1}'.format(self.test_remote_dir,
- self.test_remote_filename)
-
- @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
- def test_pickle_file_transfer_put(self):
- test_local_file_content = \
- b"This is local file content \n which is multiline " \
- b"continuing....with other character\nanother line here \n this is last line"
- # create a test file locally
- with open(self.test_local_filepath, 'wb') as file:
- file.write(test_local_file_content)
-
- # put test file to remote
- put_test_task = SFTPOperator(
- task_id="test_sftp",
- ssh_hook=self.hook,
- local_filepath=self.test_local_filepath,
- remote_filepath=self.test_remote_filepath,
- operation=SFTPOperation.PUT,
- create_intermediate_dirs=True,
- dag=self.dag
- )
- self.assertIsNotNone(put_test_task)
- ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
- ti2.run()
-
- # check the remote file content
- check_file_task = SSHOperator(
- task_id="test_check_file",
- ssh_hook=self.hook,
- command="cat {0}".format(self.test_remote_filepath),
- do_xcom_push=True,
- dag=self.dag
- )
- self.assertIsNotNone(check_file_task)
- ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
- ti3.run()
- self.assertEqual(
- ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
- test_local_file_content)
-
- @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
- def test_file_transfer_no_intermediate_dir_error_put(self):
- test_local_file_content = \
- b"This is local file content \n which is multiline " \
- b"continuing....with other character\nanother line here \n this is last line"
- # create a test file locally
- with open(self.test_local_filepath, 'wb') as file:
- file.write(test_local_file_content)
-
- # Try to put test file to remote
- # This should raise an error with "No such file" as the directory
- # does not exist
- with self.assertRaises(Exception) as error:
- put_test_task = SFTPOperator(
- task_id="test_sftp",
- ssh_hook=self.hook,
- local_filepath=self.test_local_filepath,
- remote_filepath=self.test_remote_filepath_int_dir,
- operation=SFTPOperation.PUT,
- create_intermediate_dirs=False,
- dag=self.dag
- )
- self.assertIsNotNone(put_test_task)
- ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
- ti2.run()
- self.assertIn('No such file', str(error.exception))
-
- @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
- def test_file_transfer_with_intermediate_dir_put(self):
- test_local_file_content = \
- b"This is local file content \n which is multiline " \
- b"continuing....with other character\nanother line here \n this is last line"
- # create a test file locally
- with open(self.test_local_filepath, 'wb') as file:
- file.write(test_local_file_content)
-
- # put test file to remote
- put_test_task = SFTPOperator(
- task_id="test_sftp",
- ssh_hook=self.hook,
- local_filepath=self.test_local_filepath,
- remote_filepath=self.test_remote_filepath_int_dir,
- operation=SFTPOperation.PUT,
- create_intermediate_dirs=True,
- dag=self.dag
- )
- self.assertIsNotNone(put_test_task)
- ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
- ti2.run()
-
- # check the remote file content
- check_file_task = SSHOperator(
- task_id="test_check_file",
- ssh_hook=self.hook,
- command="cat {0}".format(self.test_remote_filepath_int_dir),
- do_xcom_push=True,
- dag=self.dag
- )
- self.assertIsNotNone(check_file_task)
- ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
- ti3.run()
- self.assertEqual(
- ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
- test_local_file_content)
-
- @conf_vars({('core', 'enable_xcom_pickling'): 'False'})
- def test_json_file_transfer_put(self):
- test_local_file_content = \
- b"This is local file content \n which is multiline " \
- b"continuing....with other character\nanother line here \n this is last line"
- # create a test file locally
- with open(self.test_local_filepath, 'wb') as file:
- file.write(test_local_file_content)
-
- # put test file to remote
- put_test_task = SFTPOperator(
- task_id="test_sftp",
- ssh_hook=self.hook,
- local_filepath=self.test_local_filepath,
- remote_filepath=self.test_remote_filepath,
- operation=SFTPOperation.PUT,
- dag=self.dag
- )
- self.assertIsNotNone(put_test_task)
- ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
- ti2.run()
-
- # check the remote file content
- check_file_task = SSHOperator(
- task_id="test_check_file",
- ssh_hook=self.hook,
- command="cat {0}".format(self.test_remote_filepath),
- do_xcom_push=True,
- dag=self.dag
- )
- self.assertIsNotNone(check_file_task)
- ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
- ti3.run()
- self.assertEqual(
- ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
- b64encode(test_local_file_content).decode('utf-8'))
-
- @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
- def test_pickle_file_transfer_get(self):
- test_remote_file_content = \
- "This is remote file content \n which is also multiline " \
- "another line here \n this is last line. EOF"
-
- # create a test file remotely
- create_file_task = SSHOperator(
- task_id="test_create_file",
- ssh_hook=self.hook,
- command="echo '{0}' > {1}".format(test_remote_file_content,
- self.test_remote_filepath),
- do_xcom_push=True,
- dag=self.dag
- )
- self.assertIsNotNone(create_file_task)
- ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
- ti1.run()
-
- # get remote file to local
- get_test_task = SFTPOperator(
- task_id="test_sftp",
- ssh_hook=self.hook,
- local_filepath=self.test_local_filepath,
- remote_filepath=self.test_remote_filepath,
- operation=SFTPOperation.GET,
- dag=self.dag
- )
- self.assertIsNotNone(get_test_task)
- ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
- ti2.run()
-
- # test the received content
- content_received = None
- with open(self.test_local_filepath, 'r') as file:
- content_received = file.read()
- self.assertEqual(content_received.strip(), test_remote_file_content)
-
- @conf_vars({('core', 'enable_xcom_pickling'): 'False'})
- def test_json_file_transfer_get(self):
- test_remote_file_content = \
- "This is remote file content \n which is also multiline " \
- "another line here \n this is last line. EOF"
-
- # create a test file remotely
- create_file_task = SSHOperator(
- task_id="test_create_file",
- ssh_hook=self.hook,
- command="echo '{0}' > {1}".format(test_remote_file_content,
- self.test_remote_filepath),
- do_xcom_push=True,
- dag=self.dag
- )
- self.assertIsNotNone(create_file_task)
- ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
- ti1.run()
-
- # get remote file to local
- get_test_task = SFTPOperator(
- task_id="test_sftp",
- ssh_hook=self.hook,
- local_filepath=self.test_local_filepath,
- remote_filepath=self.test_remote_filepath,
- operation=SFTPOperation.GET,
- dag=self.dag
- )
- self.assertIsNotNone(get_test_task)
- ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
- ti2.run()
-
- # test the received content
- content_received = None
- with open(self.test_local_filepath, 'r') as file:
- content_received = file.read()
- self.assertEqual(content_received.strip(),
- test_remote_file_content.encode('utf-8').decode('utf-8'))
-
- @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
- def test_file_transfer_no_intermediate_dir_error_get(self):
- test_remote_file_content = \
- "This is remote file content \n which is also multiline " \
- "another line here \n this is last line. EOF"
-
- # create a test file remotely
- create_file_task = SSHOperator(
- task_id="test_create_file",
- ssh_hook=self.hook,
- command="echo '{0}' > {1}".format(test_remote_file_content,
- self.test_remote_filepath),
- do_xcom_push=True,
- dag=self.dag
- )
- self.assertIsNotNone(create_file_task)
- ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
- ti1.run()
-
- # Try to GET test file from remote
- # This should raise an error with "No such file" as the directory
- # does not exist
- with self.assertRaises(Exception) as error:
- get_test_task = SFTPOperator(
- task_id="test_sftp",
- ssh_hook=self.hook,
- local_filepath=self.test_local_filepath_int_dir,
- remote_filepath=self.test_remote_filepath,
- operation=SFTPOperation.GET,
- dag=self.dag
- )
- self.assertIsNotNone(get_test_task)
- ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
- ti2.run()
- self.assertIn('No such file', str(error.exception))
-
- @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
- def test_file_transfer_with_intermediate_dir_error_get(self):
- test_remote_file_content = \
- "This is remote file content \n which is also multiline " \
- "another line here \n this is last line. EOF"
-
- # create a test file remotely
- create_file_task = SSHOperator(
- task_id="test_create_file",
- ssh_hook=self.hook,
- command="echo '{0}' > {1}".format(test_remote_file_content,
- self.test_remote_filepath),
- do_xcom_push=True,
- dag=self.dag
- )
- self.assertIsNotNone(create_file_task)
- ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
- ti1.run()
-
- # get remote file to local
- get_test_task = SFTPOperator(
- task_id="test_sftp",
- ssh_hook=self.hook,
- local_filepath=self.test_local_filepath_int_dir,
- remote_filepath=self.test_remote_filepath,
- operation=SFTPOperation.GET,
- create_intermediate_dirs=True,
- dag=self.dag
- )
- self.assertIsNotNone(get_test_task)
- ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
- ti2.run()
-
- # test the received content
- content_received = None
- with open(self.test_local_filepath_int_dir, 'r') as file:
- content_received = file.read()
- self.assertEqual(content_received.strip(), test_remote_file_content)
-
- @mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"})
- def test_arg_checking(self):
- # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
- with self.assertRaisesRegex(AirflowException,
- "Cannot operate without ssh_hook or ssh_conn_id."):
- task_0 = SFTPOperator(
- task_id="test_sftp",
- local_filepath=self.test_local_filepath,
- remote_filepath=self.test_remote_filepath,
- operation=SFTPOperation.PUT,
- dag=self.dag
- )
- task_0.execute(None)
-
- # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook
- task_1 = SFTPOperator(
- task_id="test_sftp",
- ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook
- ssh_conn_id=TEST_CONN_ID,
- local_filepath=self.test_local_filepath,
- remote_filepath=self.test_remote_filepath,
- operation=SFTPOperation.PUT,
- dag=self.dag
- )
- try:
- task_1.execute(None)
- except Exception:
- pass
- self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID)
-
- task_2 = SFTPOperator(
- task_id="test_sftp",
- ssh_conn_id=TEST_CONN_ID, # no ssh_hook provided
- local_filepath=self.test_local_filepath,
- remote_filepath=self.test_remote_filepath,
- operation=SFTPOperation.PUT,
- dag=self.dag
- )
- try:
- task_2.execute(None)
- except Exception:
- pass
- self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID)
-
- # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
- task_3 = SFTPOperator(
- task_id="test_sftp",
- ssh_hook=self.hook,
- ssh_conn_id=TEST_CONN_ID,
- local_filepath=self.test_local_filepath,
- remote_filepath=self.test_remote_filepath,
- operation=SFTPOperation.PUT,
- dag=self.dag
- )
- try:
- task_3.execute(None)
- except Exception:
- pass
- self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id)
-
- def delete_local_resource(self):
- if os.path.exists(self.test_local_filepath):
- os.remove(self.test_local_filepath)
- if os.path.exists(self.test_local_filepath_int_dir):
- os.remove(self.test_local_filepath_int_dir)
- if os.path.exists(self.test_local_dir):
- os.rmdir(self.test_local_dir)
-
- def delete_remote_resource(self):
- if os.path.exists(self.test_remote_filepath):
- # check the remote file content
- remove_file_task = SSHOperator(
- task_id="test_check_file",
- ssh_hook=self.hook,
- command="rm {0}".format(self.test_remote_filepath),
- do_xcom_push=True,
- dag=self.dag
- )
- self.assertIsNotNone(remove_file_task)
- ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow())
- ti3.run()
- if os.path.exists(self.test_remote_filepath_int_dir):
- os.remove(self.test_remote_filepath_int_dir)
- if os.path.exists(self.test_remote_dir):
- os.rmdir(self.test_remote_dir)
-
- def tearDown(self):
- self.delete_local_resource()
- self.delete_remote_resource()
-
-
-if __name__ == '__main__':
- unittest.main()
+class TestMovingSFTPHookToCore(TestCase):
+ def test_move_sftp_operator_to_core(self):
+ with self.assertWarns(DeprecationWarning) as warn:
+ # Reload to see deprecation warning each time
+ importlib.import_module(OLD_PATH)
+ self.assertEqual(WARNING_MESSAGE, str(warn.warning))
diff --git a/tests/contrib/sensors/test_sftp_sensor.py b/tests/contrib/sensors/test_sftp_sensor.py
index 351dac8408..f3f8c2c842 100644
--- a/tests/contrib/sensors/test_sftp_sensor.py
+++ b/tests/contrib/sensors/test_sftp_sensor.py
@@ -17,61 +17,17 @@
# specific language governing permissions and limitations
# under the License.
-import unittest
-from unittest.mock import patch
+import importlib
+from unittest import TestCase
-from paramiko import SFTP_FAILURE, SFTP_NO_SUCH_FILE
-
-from airflow.contrib.sensors.sftp_sensor import SFTPSensor
+OLD_PATH = "airflow.contrib.sensors.sftp_sensor"
+NEW_PATH = "airflow.providers.sftp.sensors.sftp_sensor"
+WARNING_MESSAGE = "This module is deprecated. Please use `{}`.".format(NEW_PATH)
-class TestSFTPSensor(unittest.TestCase):
- @patch('airflow.contrib.sensors.sftp_sensor.SFTPHook')
- def test_file_present(self, sftp_hook_mock):
- sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
- sftp_sensor = SFTPSensor(
- task_id='unit_test',
- path='/path/to/file/1970-01-01.txt')
- context = {
- 'ds': '1970-01-01'
- }
- output = sftp_sensor.poke(context)
- sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
- '/path/to/file/1970-01-01.txt')
- self.assertTrue(output)
-
- @patch('airflow.contrib.sensors.sftp_sensor.SFTPHook')
- def test_file_absent(self, sftp_hook_mock):
- sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(
- SFTP_NO_SUCH_FILE, 'File missing')
- sftp_sensor = SFTPSensor(
- task_id='unit_test',
- path='/path/to/file/1970-01-01.txt')
- context = {
- 'ds': '1970-01-01'
- }
- output = sftp_sensor.poke(context)
- sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
- '/path/to/file/1970-01-01.txt')
- self.assertFalse(output)
-
- @patch('airflow.contrib.sensors.sftp_sensor.SFTPHook')
- def test_sftp_failure(self, sftp_hook_mock):
- sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(
- SFTP_FAILURE, 'SFTP failure')
- sftp_sensor = SFTPSensor(
- task_id='unit_test',
- path='/path/to/file/1970-01-01.txt')
- context = {
- 'ds': '1970-01-01'
- }
- with self.assertRaises(OSError):
- sftp_sensor.poke(context)
- sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
- '/path/to/file/1970-01-01.txt')
-
- def test_hook_not_created_during_init(self):
- sftp_sensor = SFTPSensor(
- task_id='unit_test',
- path='/path/to/file/1970-01-01.txt')
- self.assertIsNone(sftp_sensor.hook)
+class TestMovingSFTPHookToCore(TestCase):
+ def test_move_sftp_sensor_to_core(self):
+ with self.assertWarns(DeprecationWarning) as warn:
+ # Reload to see deprecation warning each time
+ importlib.import_module(OLD_PATH)
+ self.assertEqual(WARNING_MESSAGE, str(warn.warning))
diff --git a/tests/providers/sftp/__init__.py b/tests/providers/sftp/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/sftp/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/sftp/hooks/__init__.py b/tests/providers/sftp/hooks/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/sftp/hooks/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/sftp/hooks/test_sftp_hook.py b/tests/providers/sftp/hooks/test_sftp_hook.py
new file mode 100644
index 0000000000..ce71994774
--- /dev/null
+++ b/tests/providers/sftp/hooks/test_sftp_hook.py
@@ -0,0 +1,247 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import shutil
+import unittest
+from unittest import mock
+
+import pysftp
+from parameterized import parameterized
+
+from airflow.models import Connection
+from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
+from airflow.utils.db import provide_session
+
+TMP_PATH = '/tmp'
+TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
+SUB_DIR = "sub_dir"
+TMP_FILE_FOR_TESTS = 'test_file.txt'
+
+SFTP_CONNECTION_USER = "root"
+
+
+class TestSFTPHook(unittest.TestCase):
+
+ @provide_session
+ def update_connection(self, login, session=None):
+ connection = (session.query(Connection).
+ filter(Connection.conn_id == "sftp_default")
+ .first())
+ old_login = connection.login
+ connection.login = login
+ session.commit()
+ return old_login
+
+ def setUp(self):
+ self.old_login = self.update_connection(SFTP_CONNECTION_USER)
+ self.hook = SFTPHook()
+ os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))
+
+ with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file:
+ file.write('Test file')
+ with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file:
+ file.write('Test file')
+
+ def test_get_conn(self):
+ output = self.hook.get_conn()
+ self.assertEqual(type(output), pysftp.Connection)
+
+ def test_close_conn(self):
+ self.hook.conn = self.hook.get_conn()
+ self.assertTrue(self.hook.conn is not None)
+ self.hook.close_conn()
+ self.assertTrue(self.hook.conn is None)
+
+ def test_describe_directory(self):
+ output = self.hook.describe_directory(TMP_PATH)
+ self.assertTrue(TMP_DIR_FOR_TESTS in output)
+
+ def test_list_directory(self):
+ output = self.hook.list_directory(
+ path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
+ self.assertEqual(output, [SUB_DIR])
+
+ def test_create_and_delete_directory(self):
+ new_dir_name = 'new_dir'
+ self.hook.create_directory(os.path.join(
+ TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
+ output = self.hook.describe_directory(
+ os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
+ self.assertTrue(new_dir_name in output)
+ self.hook.delete_directory(os.path.join(
+ TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
+ output = self.hook.describe_directory(
+ os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
+ self.assertTrue(new_dir_name not in output)
+
+ def test_create_and_delete_directories(self):
+ base_dir = "base_dir"
+ sub_dir = "sub_dir"
+ new_dir_path = os.path.join(base_dir, sub_dir)
+ self.hook.create_directory(os.path.join(
+ TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
+ output = self.hook.describe_directory(
+ os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
+ self.assertTrue(base_dir in output)
+ output = self.hook.describe_directory(
+ os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
+ self.assertTrue(sub_dir in output)
+ self.hook.delete_directory(os.path.join(
+ TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
+ self.hook.delete_directory(os.path.join(
+ TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
+ output = self.hook.describe_directory(
+ os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
+ self.assertTrue(new_dir_path not in output)
+ self.assertTrue(base_dir not in output)
+
+ def test_store_retrieve_and_delete_file(self):
+ self.hook.store_file(
+ remote_full_path=os.path.join(
+ TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
+ local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
+ )
+ output = self.hook.list_directory(
+ path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
+ self.assertEqual(output, [SUB_DIR, TMP_FILE_FOR_TESTS])
+ retrieved_file_name = 'retrieved.txt'
+ self.hook.retrieve_file(
+ remote_full_path=os.path.join(
+ TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
+ local_full_path=os.path.join(TMP_PATH, retrieved_file_name)
+ )
+ self.assertTrue(retrieved_file_name in os.listdir(TMP_PATH))
+ os.remove(os.path.join(TMP_PATH, retrieved_file_name))
+ self.hook.delete_file(path=os.path.join(
+ TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
+ output = self.hook.list_directory(
+ path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
+ self.assertEqual(output, [SUB_DIR])
+
+ def test_get_mod_time(self):
+ self.hook.store_file(
+ remote_full_path=os.path.join(
+ TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
+ local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
+ )
+ output = self.hook.get_mod_time(path=os.path.join(
+ TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
+ self.assertEqual(len(output), 14)
+
+ @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
+ def test_no_host_key_check_default(self, get_connection):
+ connection = Connection(login='login', host='host')
+ get_connection.return_value = connection
+ hook = SFTPHook()
+ self.assertEqual(hook.no_host_key_check, False)
+
+ @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
+ def test_no_host_key_check_enabled(self, get_connection):
+ connection = Connection(
+ login='login', host='host',
+ extra='{"no_host_key_check": true}')
+
+ get_connection.return_value = connection
+ hook = SFTPHook()
+ self.assertEqual(hook.no_host_key_check, True)
+
+ @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
+ def test_no_host_key_check_disabled(self, get_connection):
+ connection = Connection(
+ login='login', host='host',
+ extra='{"no_host_key_check": false}')
+
+ get_connection.return_value = connection
+ hook = SFTPHook()
+ self.assertEqual(hook.no_host_key_check, False)
+
+ @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
+ def test_no_host_key_check_disabled_for_all_but_true(self, get_connection):
+ connection = Connection(
+ login='login', host='host',
+ extra='{"no_host_key_check": "foo"}')
+
+ get_connection.return_value = connection
+ hook = SFTPHook()
+ self.assertEqual(hook.no_host_key_check, False)
+
+ @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
+ def test_no_host_key_check_ignore(self, get_connection):
+ connection = Connection(
+ login='login', host='host',
+ extra='{"ignore_hostkey_verification": true}')
+
+ get_connection.return_value = connection
+ hook = SFTPHook()
+ self.assertEqual(hook.no_host_key_check, True)
+
+ @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
+ def test_no_host_key_check_no_ignore(self, get_connection):
+ connection = Connection(
+ login='login', host='host',
+ extra='{"ignore_hostkey_verification": false}')
+
+ get_connection.return_value = connection
+ hook = SFTPHook()
+ self.assertEqual(hook.no_host_key_check, False)
+
+ @parameterized.expand([
+ (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True),
+ (os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True),
+ (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False),
+ (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False),
+ ])
+ def test_path_exists(self, path, exists):
+ result = self.hook.path_exists(path)
+ self.assertEqual(result, exists)
+
+ @parameterized.expand([
+ ("test/path/file.bin", None, None, True),
+ ("test/path/file.bin", "test", None, True),
+ ("test/path/file.bin", "test/", None, True),
+ ("test/path/file.bin", None, "bin", True),
+ ("test/path/file.bin", "test", "bin", True),
+ ("test/path/file.bin", "test/", "file.bin", True),
+ ("test/path/file.bin", None, "file.bin", True),
+ ("test/path/file.bin", "diff", None, False),
+ ("test/path/file.bin", "test//", None, False),
+ ("test/path/file.bin", None, ".txt", False),
+ ("test/path/file.bin", "diff", ".txt", False),
+ ])
+ def test_path_match(self, path, prefix, delimiter, match):
+ result = self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter)
+ self.assertEqual(result, match)
+
+ def test_get_tree_map(self):
+ tree_map = self.hook.get_tree_map(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
+ files, dirs, unknowns = tree_map
+
+ self.assertEqual(files, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS)])
+ self.assertEqual(dirs, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)])
+ self.assertEqual(unknowns, [])
+
+ def tearDown(self):
+ shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
+ os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
+ self.update_connection(self.old_login)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/providers/sftp/operators/__init__.py b/tests/providers/sftp/operators/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/sftp/operators/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/sftp/operators/test_sftp_operator.py b/tests/providers/sftp/operators/test_sftp_operator.py
new file mode 100644
index 0000000000..e05d138a27
--- /dev/null
+++ b/tests/providers/sftp/operators/test_sftp_operator.py
@@ -0,0 +1,452 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import unittest
+from base64 import b64encode
+from unittest import mock
+
+from airflow import AirflowException
+from airflow.contrib.operators.ssh_operator import SSHOperator
+from airflow.models import DAG, TaskInstance
+from airflow.providers.sftp.operators.sftp_operator import SFTPOperation, SFTPOperator
+from airflow.utils import timezone
+from airflow.utils.timezone import datetime
+from tests.test_utils.config import conf_vars
+
+TEST_DAG_ID = 'unit_tests_sftp_op'
+DEFAULT_DATE = datetime(2017, 1, 1)
+TEST_CONN_ID = "conn_id_for_testing"
+
+
+class TestSFTPOperator(unittest.TestCase):
+ def setUp(self):
+ from airflow.contrib.hooks.ssh_hook import SSHHook
+
+ hook = SSHHook(ssh_conn_id='ssh_default')
+ hook.no_host_key_check = True
+ args = {
+ 'owner': 'airflow',
+ 'start_date': DEFAULT_DATE,
+ }
+ dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
+ dag.schedule_interval = '@once'
+ self.hook = hook
+ self.dag = dag
+ self.test_dir = "/tmp"
+ self.test_local_dir = "/tmp/tmp2"
+ self.test_remote_dir = "/tmp/tmp1"
+ self.test_local_filename = 'test_local_file'
+ self.test_remote_filename = 'test_remote_file'
+ self.test_local_filepath = '{0}/{1}'.format(self.test_dir,
+ self.test_local_filename)
+ # Local Filepath with Intermediate Directory
+ self.test_local_filepath_int_dir = '{0}/{1}'.format(self.test_local_dir,
+ self.test_local_filename)
+ self.test_remote_filepath = '{0}/{1}'.format(self.test_dir,
+ self.test_remote_filename)
+ # Remote Filepath with Intermediate Directory
+ self.test_remote_filepath_int_dir = '{0}/{1}'.format(self.test_remote_dir,
+ self.test_remote_filename)
+
+ @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
+ def test_pickle_file_transfer_put(self):
+ test_local_file_content = \
+ b"This is local file content \n which is multiline " \
+ b"continuing....with other character\nanother line here \n this is last line"
+ # create a test file locally
+ with open(self.test_local_filepath, 'wb') as file:
+ file.write(test_local_file_content)
+
+ # put test file to remote
+ put_test_task = SFTPOperator(
+ task_id="test_sftp",
+ ssh_hook=self.hook,
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.PUT,
+ create_intermediate_dirs=True,
+ dag=self.dag
+ )
+ self.assertIsNotNone(put_test_task)
+ ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
+ ti2.run()
+
+ # check the remote file content
+ check_file_task = SSHOperator(
+ task_id="test_check_file",
+ ssh_hook=self.hook,
+ command="cat {0}".format(self.test_remote_filepath),
+ do_xcom_push=True,
+ dag=self.dag
+ )
+ self.assertIsNotNone(check_file_task)
+ ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
+ ti3.run()
+ self.assertEqual(
+ ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
+ test_local_file_content)
+
+ @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
+ def test_file_transfer_no_intermediate_dir_error_put(self):
+ test_local_file_content = \
+ b"This is local file content \n which is multiline " \
+ b"continuing....with other character\nanother line here \n this is last line"
+ # create a test file locally
+ with open(self.test_local_filepath, 'wb') as file:
+ file.write(test_local_file_content)
+
+ # Try to put test file to remote
+ # This should raise an error with "No such file" as the directory
+ # does not exist
+ with self.assertRaises(Exception) as error:
+ put_test_task = SFTPOperator(
+ task_id="test_sftp",
+ ssh_hook=self.hook,
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath_int_dir,
+ operation=SFTPOperation.PUT,
+ create_intermediate_dirs=False,
+ dag=self.dag
+ )
+ self.assertIsNotNone(put_test_task)
+ ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
+ ti2.run()
+ self.assertIn('No such file', str(error.exception))
+
+ @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
+ def test_file_transfer_with_intermediate_dir_put(self):
+ test_local_file_content = \
+ b"This is local file content \n which is multiline " \
+ b"continuing....with other character\nanother line here \n this is last line"
+ # create a test file locally
+ with open(self.test_local_filepath, 'wb') as file:
+ file.write(test_local_file_content)
+
+ # put test file to remote
+ put_test_task = SFTPOperator(
+ task_id="test_sftp",
+ ssh_hook=self.hook,
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath_int_dir,
+ operation=SFTPOperation.PUT,
+ create_intermediate_dirs=True,
+ dag=self.dag
+ )
+ self.assertIsNotNone(put_test_task)
+ ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
+ ti2.run()
+
+ # check the remote file content
+ check_file_task = SSHOperator(
+ task_id="test_check_file",
+ ssh_hook=self.hook,
+ command="cat {0}".format(self.test_remote_filepath_int_dir),
+ do_xcom_push=True,
+ dag=self.dag
+ )
+ self.assertIsNotNone(check_file_task)
+ ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
+ ti3.run()
+ self.assertEqual(
+ ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
+ test_local_file_content)
+
+ @conf_vars({('core', 'enable_xcom_pickling'): 'False'})
+ def test_json_file_transfer_put(self):
+ test_local_file_content = \
+ b"This is local file content \n which is multiline " \
+ b"continuing....with other character\nanother line here \n this is last line"
+ # create a test file locally
+ with open(self.test_local_filepath, 'wb') as file:
+ file.write(test_local_file_content)
+
+ # put test file to remote
+ put_test_task = SFTPOperator(
+ task_id="test_sftp",
+ ssh_hook=self.hook,
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.PUT,
+ dag=self.dag
+ )
+ self.assertIsNotNone(put_test_task)
+ ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
+ ti2.run()
+
+ # check the remote file content
+ check_file_task = SSHOperator(
+ task_id="test_check_file",
+ ssh_hook=self.hook,
+ command="cat {0}".format(self.test_remote_filepath),
+ do_xcom_push=True,
+ dag=self.dag
+ )
+ self.assertIsNotNone(check_file_task)
+ ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
+ ti3.run()
+ self.assertEqual(
+ ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
+ b64encode(test_local_file_content).decode('utf-8'))
+
+ @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
+ def test_pickle_file_transfer_get(self):
+ test_remote_file_content = \
+ "This is remote file content \n which is also multiline " \
+ "another line here \n this is last line. EOF"
+
+ # create a test file remotely
+ create_file_task = SSHOperator(
+ task_id="test_create_file",
+ ssh_hook=self.hook,
+ command="echo '{0}' > {1}".format(test_remote_file_content,
+ self.test_remote_filepath),
+ do_xcom_push=True,
+ dag=self.dag
+ )
+ self.assertIsNotNone(create_file_task)
+ ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
+ ti1.run()
+
+ # get remote file to local
+ get_test_task = SFTPOperator(
+ task_id="test_sftp",
+ ssh_hook=self.hook,
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.GET,
+ dag=self.dag
+ )
+ self.assertIsNotNone(get_test_task)
+ ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
+ ti2.run()
+
+ # test the received content
+ content_received = None
+ with open(self.test_local_filepath, 'r') as file:
+ content_received = file.read()
+ self.assertEqual(content_received.strip(), test_remote_file_content)
+
+ @conf_vars({('core', 'enable_xcom_pickling'): 'False'})
+ def test_json_file_transfer_get(self):
+ test_remote_file_content = \
+ "This is remote file content \n which is also multiline " \
+ "another line here \n this is last line. EOF"
+
+ # create a test file remotely
+ create_file_task = SSHOperator(
+ task_id="test_create_file",
+ ssh_hook=self.hook,
+ command="echo '{0}' > {1}".format(test_remote_file_content,
+ self.test_remote_filepath),
+ do_xcom_push=True,
+ dag=self.dag
+ )
+ self.assertIsNotNone(create_file_task)
+ ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
+ ti1.run()
+
+ # get remote file to local
+ get_test_task = SFTPOperator(
+ task_id="test_sftp",
+ ssh_hook=self.hook,
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.GET,
+ dag=self.dag
+ )
+ self.assertIsNotNone(get_test_task)
+ ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
+ ti2.run()
+
+ # test the received content
+ content_received = None
+ with open(self.test_local_filepath, 'r') as file:
+ content_received = file.read()
+ self.assertEqual(content_received.strip(),
+ test_remote_file_content.encode('utf-8').decode('utf-8'))
+
+ @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
+ def test_file_transfer_no_intermediate_dir_error_get(self):
+ test_remote_file_content = \
+ "This is remote file content \n which is also multiline " \
+ "another line here \n this is last line. EOF"
+
+ # create a test file remotely
+ create_file_task = SSHOperator(
+ task_id="test_create_file",
+ ssh_hook=self.hook,
+ command="echo '{0}' > {1}".format(test_remote_file_content,
+ self.test_remote_filepath),
+ do_xcom_push=True,
+ dag=self.dag
+ )
+ self.assertIsNotNone(create_file_task)
+ ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
+ ti1.run()
+
+ # Try to GET test file from remote
+ # This should raise an error with "No such file" as the directory
+ # does not exist
+ with self.assertRaises(Exception) as error:
+ get_test_task = SFTPOperator(
+ task_id="test_sftp",
+ ssh_hook=self.hook,
+ local_filepath=self.test_local_filepath_int_dir,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.GET,
+ dag=self.dag
+ )
+ self.assertIsNotNone(get_test_task)
+ ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
+ ti2.run()
+ self.assertIn('No such file', str(error.exception))
+
+ @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
+ def test_file_transfer_with_intermediate_dir_error_get(self):
+ test_remote_file_content = \
+ "This is remote file content \n which is also multiline " \
+ "another line here \n this is last line. EOF"
+
+ # create a test file remotely
+ create_file_task = SSHOperator(
+ task_id="test_create_file",
+ ssh_hook=self.hook,
+ command="echo '{0}' > {1}".format(test_remote_file_content,
+ self.test_remote_filepath),
+ do_xcom_push=True,
+ dag=self.dag
+ )
+ self.assertIsNotNone(create_file_task)
+ ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
+ ti1.run()
+
+ # get remote file to local
+ get_test_task = SFTPOperator(
+ task_id="test_sftp",
+ ssh_hook=self.hook,
+ local_filepath=self.test_local_filepath_int_dir,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.GET,
+ create_intermediate_dirs=True,
+ dag=self.dag
+ )
+ self.assertIsNotNone(get_test_task)
+ ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
+ ti2.run()
+
+ # test the received content
+ content_received = None
+ with open(self.test_local_filepath_int_dir, 'r') as file:
+ content_received = file.read()
+ self.assertEqual(content_received.strip(), test_remote_file_content)
+
+ @mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"})
+ def test_arg_checking(self):
+ # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
+ with self.assertRaisesRegex(AirflowException,
+ "Cannot operate without ssh_hook or ssh_conn_id."):
+ task_0 = SFTPOperator(
+ task_id="test_sftp",
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.PUT,
+ dag=self.dag
+ )
+ task_0.execute(None)
+
+ # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook
+ task_1 = SFTPOperator(
+ task_id="test_sftp",
+ ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook
+ ssh_conn_id=TEST_CONN_ID,
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.PUT,
+ dag=self.dag
+ )
+ try:
+ task_1.execute(None)
+ except Exception: # pylint: disable=broad-except
+ pass
+ self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID)
+
+ task_2 = SFTPOperator(
+ task_id="test_sftp",
+ ssh_conn_id=TEST_CONN_ID, # no ssh_hook provided
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.PUT,
+ dag=self.dag
+ )
+ try:
+ task_2.execute(None)
+ except Exception: # pylint: disable=broad-except
+ pass
+ self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID)
+
+ # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
+ task_3 = SFTPOperator(
+ task_id="test_sftp",
+ ssh_hook=self.hook,
+ ssh_conn_id=TEST_CONN_ID,
+ local_filepath=self.test_local_filepath,
+ remote_filepath=self.test_remote_filepath,
+ operation=SFTPOperation.PUT,
+ dag=self.dag
+ )
+ try:
+ task_3.execute(None)
+ except Exception: # pylint: disable=broad-except
+ pass
+ self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id)
+
+ def delete_local_resource(self):
+ if os.path.exists(self.test_local_filepath):
+ os.remove(self.test_local_filepath)
+ if os.path.exists(self.test_local_filepath_int_dir):
+ os.remove(self.test_local_filepath_int_dir)
+ if os.path.exists(self.test_local_dir):
+ os.rmdir(self.test_local_dir)
+
+ def delete_remote_resource(self):
+ if os.path.exists(self.test_remote_filepath):
+ # check the remote file content
+ remove_file_task = SSHOperator(
+ task_id="test_check_file",
+ ssh_hook=self.hook,
+ command="rm {0}".format(self.test_remote_filepath),
+ do_xcom_push=True,
+ dag=self.dag
+ )
+ self.assertIsNotNone(remove_file_task)
+ ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow())
+ ti3.run()
+ if os.path.exists(self.test_remote_filepath_int_dir):
+ os.remove(self.test_remote_filepath_int_dir)
+ if os.path.exists(self.test_remote_dir):
+ os.rmdir(self.test_remote_dir)
+
+ def tearDown(self):
+ self.delete_local_resource()
+ self.delete_remote_resource()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/providers/sftp/sensors/__init__.py b/tests/providers/sftp/sensors/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/sftp/sensors/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/sftp/sensors/test_sftp_sensor.py b/tests/providers/sftp/sensors/test_sftp_sensor.py
new file mode 100644
index 0000000000..163f72073a
--- /dev/null
+++ b/tests/providers/sftp/sensors/test_sftp_sensor.py
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+from unittest.mock import patch
+
+from paramiko import SFTP_FAILURE, SFTP_NO_SUCH_FILE
+
+from airflow.providers.sftp.sensors.sftp_sensor import SFTPSensor
+
+
+class TestSFTPSensor(unittest.TestCase):
+ @patch('airflow.providers.sftp.sensors.sftp_sensor.SFTPHook')
+ def test_file_present(self, sftp_hook_mock):
+ sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
+ sftp_sensor = SFTPSensor(
+ task_id='unit_test',
+ path='/path/to/file/1970-01-01.txt')
+ context = {
+ 'ds': '1970-01-01'
+ }
+ output = sftp_sensor.poke(context)
+ sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
+ '/path/to/file/1970-01-01.txt')
+ self.assertTrue(output)
+
+ @patch('airflow.providers.sftp.sensors.sftp_sensor.SFTPHook')
+ def test_file_absent(self, sftp_hook_mock):
+ sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(
+ SFTP_NO_SUCH_FILE, 'File missing')
+ sftp_sensor = SFTPSensor(
+ task_id='unit_test',
+ path='/path/to/file/1970-01-01.txt')
+ context = {
+ 'ds': '1970-01-01'
+ }
+ output = sftp_sensor.poke(context)
+ sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
+ '/path/to/file/1970-01-01.txt')
+ self.assertFalse(output)
+
+ @patch('airflow.providers.sftp.sensors.sftp_sensor.SFTPHook')
+ def test_sftp_failure(self, sftp_hook_mock):
+ sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(
+ SFTP_FAILURE, 'SFTP failure')
+ sftp_sensor = SFTPSensor(
+ task_id='unit_test',
+ path='/path/to/file/1970-01-01.txt')
+ context = {
+ 'ds': '1970-01-01'
+ }
+ with self.assertRaises(OSError):
+ sftp_sensor.poke(context)
+ sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
+ '/path/to/file/1970-01-01.txt')
+
+ def test_hook_not_created_during_init(self):
+ sftp_sensor = SFTPSensor(
+ task_id='unit_test',
+ path='/path/to/file/1970-01-01.txt')
+ self.assertIsNone(sftp_sensor.hook)