diff --git a/airflow/contrib/hooks/sftp_hook.py b/airflow/contrib/hooks/sftp_hook.py index 48abfb6174..9dea2ee298 100644 --- a/airflow/contrib/hooks/sftp_hook.py +++ b/airflow/contrib/hooks/sftp_hook.py @@ -16,274 +16,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +This module is deprecated. Please use `airflow.providers.sftp.hooks.sftp_hook`. +""" -import datetime -import stat -from typing import Dict, List, Optional, Tuple +import warnings -import pysftp +# pylint: disable=unused-import +from airflow.providers.sftp.hooks.sftp_hook import SFTPHook # noqa -from airflow.contrib.hooks.ssh_hook import SSHHook - - -class SFTPHook(SSHHook): - """ - This hook is inherited from SSH hook. Please refer to SSH hook for the input - arguments. - - Interact with SFTP. Aims to be interchangeable with FTPHook. - - :Pitfalls:: - - - In contrast with FTPHook describe_directory only returns size, type and - modify. It doesn't return unix.owner, unix.mode, perm, unix.group and - unique. - - retrieve_file and store_file only take a local full path and not a - buffer. - - If no mode is passed to create_directory it will be created with 777 - permissions. - - Errors that may occur throughout but should be handled downstream. - """ - - def __init__(self, ftp_conn_id: str = 'sftp_default', *args, **kwargs) -> None: - kwargs['ssh_conn_id'] = ftp_conn_id - super().__init__(*args, **kwargs) - - self.conn = None - self.private_key_pass = None - - # Fail for unverified hosts, unless this is explicitly allowed - self.no_host_key_check = False - - if self.ssh_conn_id is not None: - conn = self.get_connection(self.ssh_conn_id) - if conn.extra is not None: - extra_options = conn.extra_dejson - if 'private_key_pass' in extra_options: - self.private_key_pass = extra_options.get('private_key_pass', None) - - # For backward compatibility - # TODO: remove in Airflow 2.1 - import warnings - if 'ignore_hostkey_verification' in extra_options: - warnings.warn( - 'Extra option `ignore_hostkey_verification` is deprecated.' - 'Please use `no_host_key_check` instead.' - 'This option will be removed in Airflow 2.1', - DeprecationWarning, - stacklevel=2, - ) - self.no_host_key_check = str( - extra_options['ignore_hostkey_verification'] - ).lower() == 'true' - - if 'no_host_key_check' in extra_options: - self.no_host_key_check = str( - extra_options['no_host_key_check']).lower() == 'true' - - if 'private_key' in extra_options: - warnings.warn( - 'Extra option `private_key` is deprecated.' - 'Please use `key_file` instead.' - 'This option will be removed in Airflow 2.1', - DeprecationWarning, - stacklevel=2, - ) - self.key_file = extra_options.get('private_key') - - def get_conn(self) -> pysftp.Connection: - """ - Returns an SFTP connection object - """ - if self.conn is None: - cnopts = pysftp.CnOpts() - if self.no_host_key_check: - cnopts.hostkeys = None - cnopts.compression = self.compress - conn_params = { - 'host': self.remote_host, - 'port': self.port, - 'username': self.username, - 'cnopts': cnopts - } - if self.password and self.password.strip(): - conn_params['password'] = self.password - if self.key_file: - conn_params['private_key'] = self.key_file - if self.private_key_pass: - conn_params['private_key_pass'] = self.private_key_pass - - self.conn = pysftp.Connection(**conn_params) - return self.conn - - def close_conn(self) -> None: - """ - Closes the connection. An error will occur if the - connection wasnt ever opened. - """ - conn = self.conn - conn.close() # type: ignore - self.conn = None - - def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]: - """ - Returns a dictionary of {filename: {attributes}} for all files - on the remote system (where the MLSD command is supported). - - :param path: full path to the remote directory - :type path: str - """ - conn = self.get_conn() - flist = conn.listdir_attr(path) - files = {} - for f in flist: - modify = datetime.datetime.fromtimestamp( - f.st_mtime).strftime('%Y%m%d%H%M%S') - files[f.filename] = { - 'size': f.st_size, - 'type': 'dir' if stat.S_ISDIR(f.st_mode) else 'file', - 'modify': modify} - return files - - def list_directory(self, path: str) -> List[str]: - """ - Returns a list of files on the remote system. - - :param path: full path to the remote directory to list - :type path: str - """ - conn = self.get_conn() - files = conn.listdir(path) - return files - - def create_directory(self, path: str, mode: int = 777) -> None: - """ - Creates a directory on the remote system. - - :param path: full path to the remote directory to create - :type path: str - :param mode: int representation of octal mode for directory - """ - conn = self.get_conn() - conn.makedirs(path, mode) - - def delete_directory(self, path: str) -> None: - """ - Deletes a directory on the remote system. - - :param path: full path to the remote directory to delete - :type path: str - """ - conn = self.get_conn() - conn.rmdir(path) - - def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None: - """ - Transfers the remote file to a local location. - If local_full_path is a string path, the file will be put - at that location - - :param remote_full_path: full path to the remote file - :type remote_full_path: str - :param local_full_path: full path to the local file - :type local_full_path: str - """ - conn = self.get_conn() - self.log.info('Retrieving file from FTP: %s', remote_full_path) - conn.get(remote_full_path, local_full_path) - self.log.info('Finished retrieving file from FTP: %s', remote_full_path) - - def store_file(self, remote_full_path: str, local_full_path: str) -> None: - """ - Transfers a local file to the remote location. - If local_full_path_or_buffer is a string path, the file will be read - from that location - - :param remote_full_path: full path to the remote file - :type remote_full_path: str - :param local_full_path: full path to the local file - :type local_full_path: str - """ - conn = self.get_conn() - conn.put(local_full_path, remote_full_path) - - def delete_file(self, path: str) -> None: - """ - Removes a file on the FTP Server - - :param path: full path to the remote file - :type path: str - """ - conn = self.get_conn() - conn.remove(path) - - def get_mod_time(self, path: str) -> str: - conn = self.get_conn() - ftp_mdtm = conn.stat(path).st_mtime - return datetime.datetime.fromtimestamp(ftp_mdtm).strftime('%Y%m%d%H%M%S') - - def path_exists(self, path: str) -> bool: - """ - Returns True if a remote entity exists - - :param path: full path to the remote file or directory - :type path: str - """ - conn = self.get_conn() - return conn.exists(path) - - @staticmethod - def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None) -> bool: - """ - Return True if given path starts with prefix (if set) and ends with delimiter (if set). - - :param path: path to be checked - :type path: str - :param prefix: if set path will be checked is starting with prefix - :type prefix: str - :param delimiter: if set path will be checked is ending with suffix - :type delimiter: str - :return: bool - """ - if prefix is not None and not path.startswith(prefix): - return False - if delimiter is not None and not path.endswith(delimiter): - return False - return True - - def get_tree_map( - self, path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None - ) -> Tuple[List[str], List[str], List[str]]: - """ - Return tuple with recursive lists of files, directories and unknown paths from given path. - It is possible to filter results by giving prefix and/or delimiter parameters. - - :param path: path from which tree will be built - :type path: str - :param prefix: if set paths will be added if start with prefix - :type prefix: str - :param delimiter: if set paths will be added if end with delimiter - :type delimiter: str - :return: tuple with list of files, dirs and unknown items - :rtype: Tuple[List[str], List[str], List[str]] - """ - conn = self.get_conn() - files, dirs, unknowns = [], [], [] # type: List[str], List[str], List[str] - - def append_matching_path_callback(list_): - return ( - lambda item: list_.append(item) - if self._is_path_match(item, prefix, delimiter) - else None - ) - - conn.walktree( - remotepath=path, - fcallback=append_matching_path_callback(files), - dcallback=append_matching_path_callback(dirs), - ucallback=append_matching_path_callback(unknowns), - recurse=True, - ) - - return files, dirs, unknowns +warnings.warn( + "This module is deprecated. Please use `airflow.providers.sftp.hooks.sftp_hook`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py index d71e825282..a84f54f44d 100644 --- a/airflow/contrib/operators/sftp_operator.py +++ b/airflow/contrib/operators/sftp_operator.py @@ -16,165 +16,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import os +""" +This module is deprecated. Please use `airflow.providers.sftp.operators.sftp_operator`. +""" -from airflow.contrib.hooks.ssh_hook import SSHHook -from airflow.exceptions import AirflowException -from airflow.models import BaseOperator -from airflow.utils.decorators import apply_defaults +import warnings +# pylint: disable=unused-import +from airflow.providers.sftp.operators.sftp_operator import SFTPOperator # noqa -class SFTPOperation: - PUT = 'put' - GET = 'get' - - -class SFTPOperator(BaseOperator): - """ - SFTPOperator for transferring files from remote host to local or vice a versa. - This operator uses ssh_hook to open sftp transport channel that serve as basis - for file transfer. - - :param ssh_hook: predefined ssh_hook to use for remote execution. - Either `ssh_hook` or `ssh_conn_id` needs to be provided. - :type ssh_hook: airflow.contrib.hooks.ssh_hook.SSHHook - :param ssh_conn_id: connection id from airflow Connections. - `ssh_conn_id` will be ignored if `ssh_hook` is provided. - :type ssh_conn_id: str - :param remote_host: remote host to connect (templated) - Nullable. If provided, it will replace the `remote_host` which was - defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`. - :type remote_host: str - :param local_filepath: local file path to get or put. (templated) - :type local_filepath: str - :param remote_filepath: remote file path to get or put. (templated) - :type remote_filepath: str - :param operation: specify operation 'get' or 'put', defaults to put - :type operation: str - :param confirm: specify if the SFTP operation should be confirmed, defaults to True - :type confirm: bool - :param create_intermediate_dirs: create missing intermediate directories when - copying from remote to local and vice-versa. Default is False. - - Example: The following task would copy ``file.txt`` to the remote host - at ``/tmp/tmp1/tmp2/`` while creating ``tmp``,``tmp1`` and ``tmp2`` if they - don't exist. If the parameter is not passed it would error as the directory - does not exist. :: - - put_file = SFTPOperator( - task_id="test_sftp", - ssh_conn_id="ssh_default", - local_filepath="/tmp/file.txt", - remote_filepath="/tmp/tmp1/tmp2/file.txt", - operation="put", - create_intermediate_dirs=True, - dag=dag - ) - - :type create_intermediate_dirs: bool - """ - template_fields = ('local_filepath', 'remote_filepath', 'remote_host') - - @apply_defaults - def __init__(self, - ssh_hook=None, - ssh_conn_id=None, - remote_host=None, - local_filepath=None, - remote_filepath=None, - operation=SFTPOperation.PUT, - confirm=True, - create_intermediate_dirs=False, - *args, - **kwargs): - super().__init__(*args, **kwargs) - self.ssh_hook = ssh_hook - self.ssh_conn_id = ssh_conn_id - self.remote_host = remote_host - self.local_filepath = local_filepath - self.remote_filepath = remote_filepath - self.operation = operation - self.confirm = confirm - self.create_intermediate_dirs = create_intermediate_dirs - if not (self.operation.lower() == SFTPOperation.GET or - self.operation.lower() == SFTPOperation.PUT): - raise TypeError("unsupported operation value {0}, expected {1} or {2}" - .format(self.operation, SFTPOperation.GET, SFTPOperation.PUT)) - - def execute(self, context): - file_msg = None - try: - if self.ssh_conn_id: - if self.ssh_hook and isinstance(self.ssh_hook, SSHHook): - self.log.info("ssh_conn_id is ignored when ssh_hook is provided.") - else: - self.log.info("ssh_hook is not provided or invalid. " + - "Trying ssh_conn_id to create SSHHook.") - self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) - - if not self.ssh_hook: - raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.") - - if self.remote_host is not None: - self.log.info("remote_host is provided explicitly. " + - "It will replace the remote_host which was defined " + - "in ssh_hook or predefined in connection of ssh_conn_id.") - self.ssh_hook.remote_host = self.remote_host - - with self.ssh_hook.get_conn() as ssh_client: - sftp_client = ssh_client.open_sftp() - if self.operation.lower() == SFTPOperation.GET: - local_folder = os.path.dirname(self.local_filepath) - if self.create_intermediate_dirs: - # Create Intermediate Directories if it doesn't exist - try: - os.makedirs(local_folder) - except OSError: - if not os.path.isdir(local_folder): - raise - file_msg = "from {0} to {1}".format(self.remote_filepath, - self.local_filepath) - self.log.info("Starting to transfer %s", file_msg) - sftp_client.get(self.remote_filepath, self.local_filepath) - else: - remote_folder = os.path.dirname(self.remote_filepath) - if self.create_intermediate_dirs: - _make_intermediate_dirs( - sftp_client=sftp_client, - remote_directory=remote_folder, - ) - file_msg = "from {0} to {1}".format(self.local_filepath, - self.remote_filepath) - self.log.info("Starting to transfer file %s", file_msg) - sftp_client.put(self.local_filepath, - self.remote_filepath, - confirm=self.confirm) - - except Exception as e: - raise AirflowException("Error while transferring {0}, error: {1}" - .format(file_msg, str(e))) - - return self.local_filepath - - -def _make_intermediate_dirs(sftp_client, remote_directory): - """ - Create all the intermediate directories in a remote host - - :param sftp_client: A Paramiko SFTP client. - :param remote_directory: Absolute Path of the directory containing the file - :return: - """ - if remote_directory == '/': - sftp_client.chdir('/') - return - if remote_directory == '': - return - try: - sftp_client.chdir(remote_directory) - except OSError: - dirname, basename = os.path.split(remote_directory.rstrip('/')) - _make_intermediate_dirs(sftp_client, dirname) - sftp_client.mkdir(basename) - sftp_client.chdir(basename) - return +warnings.warn( + "This module is deprecated. Please use `airflow.providers.sftp.operators.sftp_operator`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/contrib/sensors/sftp_sensor.py b/airflow/contrib/sensors/sftp_sensor.py index f58bc27bd9..241763f9e1 100644 --- a/airflow/contrib/sensors/sftp_sensor.py +++ b/airflow/contrib/sensors/sftp_sensor.py @@ -16,40 +16,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +This module is deprecated. Please use `airflow.providers.sftp.sensors.sftp_sensor`. +""" -from paramiko import SFTP_NO_SUCH_FILE +import warnings -from airflow.contrib.hooks.sftp_hook import SFTPHook -from airflow.sensors.base_sensor_operator import BaseSensorOperator -from airflow.utils.decorators import apply_defaults +# pylint: disable=unused-import +from airflow.providers.sftp.sensors.sftp_sensor import SFTPSensor # noqa - -class SFTPSensor(BaseSensorOperator): - """ - Waits for a file or directory to be present on SFTP. - - :param path: Remote file or directory path - :type path: str - :param sftp_conn_id: The connection to run the sensor against - :type sftp_conn_id: str - """ - template_fields = ('path',) - - @apply_defaults - def __init__(self, path, sftp_conn_id='sftp_default', *args, **kwargs): - super().__init__(*args, **kwargs) - self.path = path - self.hook = None - self.sftp_conn_id = sftp_conn_id - - def poke(self, context): - self.hook = SFTPHook(self.sftp_conn_id) - self.log.info('Poking for %s', self.path) - try: - self.hook.get_mod_time(self.path) - except OSError as e: - if e.errno != SFTP_NO_SUCH_FILE: - raise e - return False - self.hook.close_conn() - return True +warnings.warn( + "This module is deprecated. Please use `airflow.providers.sftp.sensors.sftp_sensor`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/operators/gcs_to_sftp.py b/airflow/operators/gcs_to_sftp.py index 3fbcaaf0bb..dd302aafa8 100644 --- a/airflow/operators/gcs_to_sftp.py +++ b/airflow/operators/gcs_to_sftp.py @@ -24,9 +24,9 @@ from tempfile import NamedTemporaryFile from typing import Optional from airflow import AirflowException -from airflow.contrib.hooks.sftp_hook import SFTPHook from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.models import BaseOperator +from airflow.providers.sftp.hooks.sftp_hook import SFTPHook from airflow.utils.decorators import apply_defaults WILDCARD = "*" diff --git a/airflow/providers/google/cloud/operators/sftp_to_gcs.py b/airflow/providers/google/cloud/operators/sftp_to_gcs.py index 478bff1e86..7d6a6b68e6 100644 --- a/airflow/providers/google/cloud/operators/sftp_to_gcs.py +++ b/airflow/providers/google/cloud/operators/sftp_to_gcs.py @@ -23,9 +23,9 @@ from tempfile import NamedTemporaryFile from typing import Optional, Union from airflow import AirflowException -from airflow.contrib.hooks.sftp_hook import SFTPHook from airflow.gcp.hooks.gcs import GoogleCloudStorageHook from airflow.models import BaseOperator +from airflow.providers.sftp.hooks.sftp_hook import SFTPHook from airflow.utils.decorators import apply_defaults WILDCARD = "*" diff --git a/airflow/providers/sftp/__init__.py b/airflow/providers/sftp/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/airflow/providers/sftp/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/sftp/hooks/__init__.py b/airflow/providers/sftp/hooks/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/airflow/providers/sftp/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/sftp/hooks/sftp_hook.py b/airflow/providers/sftp/hooks/sftp_hook.py new file mode 100644 index 0000000000..8526798536 --- /dev/null +++ b/airflow/providers/sftp/hooks/sftp_hook.py @@ -0,0 +1,297 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains SFTP hook. +""" +import datetime +import stat +from typing import Dict, List, Optional, Tuple + +import pysftp + +from airflow.contrib.hooks.ssh_hook import SSHHook + + +class SFTPHook(SSHHook): + """ + This hook is inherited from SSH hook. Please refer to SSH hook for the input + arguments. + + Interact with SFTP. Aims to be interchangeable with FTPHook. + + :Pitfalls:: + + - In contrast with FTPHook describe_directory only returns size, type and + modify. It doesn't return unix.owner, unix.mode, perm, unix.group and + unique. + - retrieve_file and store_file only take a local full path and not a + buffer. + - If no mode is passed to create_directory it will be created with 777 + permissions. + + Errors that may occur throughout but should be handled downstream. + """ + + def __init__(self, ftp_conn_id: str = 'sftp_default', *args, **kwargs) -> None: + kwargs['ssh_conn_id'] = ftp_conn_id + super().__init__(*args, **kwargs) + + self.conn = None + self.private_key_pass = None + + # Fail for unverified hosts, unless this is explicitly allowed + self.no_host_key_check = False + + if self.ssh_conn_id is not None: + conn = self.get_connection(self.ssh_conn_id) + if conn.extra is not None: + extra_options = conn.extra_dejson + if 'private_key_pass' in extra_options: + self.private_key_pass = extra_options.get('private_key_pass', None) + + # For backward compatibility + # TODO: remove in Airflow 2.1 + import warnings + if 'ignore_hostkey_verification' in extra_options: + warnings.warn( + 'Extra option `ignore_hostkey_verification` is deprecated.' + 'Please use `no_host_key_check` instead.' + 'This option will be removed in Airflow 2.1', + DeprecationWarning, + stacklevel=2, + ) + self.no_host_key_check = str( + extra_options['ignore_hostkey_verification'] + ).lower() == 'true' + + if 'no_host_key_check' in extra_options: + self.no_host_key_check = str( + extra_options['no_host_key_check']).lower() == 'true' + + if 'private_key' in extra_options: + warnings.warn( + 'Extra option `private_key` is deprecated.' + 'Please use `key_file` instead.' + 'This option will be removed in Airflow 2.1', + DeprecationWarning, + stacklevel=2, + ) + self.key_file = extra_options.get('private_key') + + def get_conn(self) -> pysftp.Connection: + """ + Returns an SFTP connection object + """ + if self.conn is None: + cnopts = pysftp.CnOpts() + if self.no_host_key_check: + cnopts.hostkeys = None + cnopts.compression = self.compress + conn_params = { + 'host': self.remote_host, + 'port': self.port, + 'username': self.username, + 'cnopts': cnopts + } + if self.password and self.password.strip(): + conn_params['password'] = self.password + if self.key_file: + conn_params['private_key'] = self.key_file + if self.private_key_pass: + conn_params['private_key_pass'] = self.private_key_pass + + self.conn = pysftp.Connection(**conn_params) + return self.conn + + def close_conn(self) -> None: + """ + Closes the connection. An error will occur if the + connection wasnt ever opened. + """ + conn = self.conn + conn.close() # type: ignore + self.conn = None + + def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]: + """ + Returns a dictionary of {filename: {attributes}} for all files + on the remote system (where the MLSD command is supported). + + :param path: full path to the remote directory + :type path: str + """ + conn = self.get_conn() + flist = conn.listdir_attr(path) + files = {} + for f in flist: + modify = datetime.datetime.fromtimestamp( + f.st_mtime).strftime('%Y%m%d%H%M%S') + files[f.filename] = { + 'size': f.st_size, + 'type': 'dir' if stat.S_ISDIR(f.st_mode) else 'file', + 'modify': modify} + return files + + def list_directory(self, path: str) -> List[str]: + """ + Returns a list of files on the remote system. + + :param path: full path to the remote directory to list + :type path: str + """ + conn = self.get_conn() + files = conn.listdir(path) + return files + + def create_directory(self, path: str, mode: int = 777) -> None: + """ + Creates a directory on the remote system. + + :param path: full path to the remote directory to create + :type path: str + :param mode: int representation of octal mode for directory + """ + conn = self.get_conn() + conn.makedirs(path, mode) + + def delete_directory(self, path: str) -> None: + """ + Deletes a directory on the remote system. + + :param path: full path to the remote directory to delete + :type path: str + """ + conn = self.get_conn() + conn.rmdir(path) + + def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None: + """ + Transfers the remote file to a local location. + If local_full_path is a string path, the file will be put + at that location + + :param remote_full_path: full path to the remote file + :type remote_full_path: str + :param local_full_path: full path to the local file + :type local_full_path: str + """ + conn = self.get_conn() + self.log.info('Retrieving file from FTP: %s', remote_full_path) + conn.get(remote_full_path, local_full_path) + self.log.info('Finished retrieving file from FTP: %s', remote_full_path) + + def store_file(self, remote_full_path: str, local_full_path: str) -> None: + """ + Transfers a local file to the remote location. + If local_full_path_or_buffer is a string path, the file will be read + from that location + + :param remote_full_path: full path to the remote file + :type remote_full_path: str + :param local_full_path: full path to the local file + :type local_full_path: str + """ + conn = self.get_conn() + conn.put(local_full_path, remote_full_path) + + def delete_file(self, path: str) -> None: + """ + Removes a file on the FTP Server + + :param path: full path to the remote file + :type path: str + """ + conn = self.get_conn() + conn.remove(path) + + def get_mod_time(self, path: str) -> str: + """ + Returns modification time. + + :param path: full path to the remote file + :type path: str + """ + conn = self.get_conn() + ftp_mdtm = conn.stat(path).st_mtime + return datetime.datetime.fromtimestamp(ftp_mdtm).strftime('%Y%m%d%H%M%S') + + def path_exists(self, path: str) -> bool: + """ + Returns True if a remote entity exists + + :param path: full path to the remote file or directory + :type path: str + """ + conn = self.get_conn() + return conn.exists(path) + + @staticmethod + def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None) -> bool: + """ + Return True if given path starts with prefix (if set) and ends with delimiter (if set). + + :param path: path to be checked + :type path: str + :param prefix: if set path will be checked is starting with prefix + :type prefix: str + :param delimiter: if set path will be checked is ending with suffix + :type delimiter: str + :return: bool + """ + if prefix is not None and not path.startswith(prefix): + return False + if delimiter is not None and not path.endswith(delimiter): + return False + return True + + def get_tree_map( + self, path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None + ) -> Tuple[List[str], List[str], List[str]]: + """ + Return tuple with recursive lists of files, directories and unknown paths from given path. + It is possible to filter results by giving prefix and/or delimiter parameters. + + :param path: path from which tree will be built + :type path: str + :param prefix: if set paths will be added if start with prefix + :type prefix: str + :param delimiter: if set paths will be added if end with delimiter + :type delimiter: str + :return: tuple with list of files, dirs and unknown items + :rtype: Tuple[List[str], List[str], List[str]] + """ + conn = self.get_conn() + files, dirs, unknowns = [], [], [] # type: List[str], List[str], List[str] + + def append_matching_path_callback(list_): + return ( + lambda item: list_.append(item) + if self._is_path_match(item, prefix, delimiter) + else None + ) + + conn.walktree( + remotepath=path, + fcallback=append_matching_path_callback(files), + dcallback=append_matching_path_callback(dirs), + ucallback=append_matching_path_callback(unknowns), + recurse=True, + ) + + return files, dirs, unknowns diff --git a/airflow/providers/sftp/operators/__init__.py b/airflow/providers/sftp/operators/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/airflow/providers/sftp/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/sftp/operators/sftp_operator.py b/airflow/providers/sftp/operators/sftp_operator.py new file mode 100644 index 0000000000..a2753a91a9 --- /dev/null +++ b/airflow/providers/sftp/operators/sftp_operator.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains SFTP operator. +""" +import os + +from airflow.contrib.hooks.ssh_hook import SSHHook +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults + + +# pylint: disable=missing-docstring +class SFTPOperation: + PUT = 'put' + GET = 'get' + + +class SFTPOperator(BaseOperator): + """ + SFTPOperator for transferring files from remote host to local or vice a versa. + This operator uses ssh_hook to open sftp transport channel that serve as basis + for file transfer. + + :param ssh_hook: predefined ssh_hook to use for remote execution. + Either `ssh_hook` or `ssh_conn_id` needs to be provided. + :type ssh_hook: airflow.contrib.hooks.ssh_hook.SSHHook + :param ssh_conn_id: connection id from airflow Connections. + `ssh_conn_id` will be ignored if `ssh_hook` is provided. + :type ssh_conn_id: str + :param remote_host: remote host to connect (templated) + Nullable. If provided, it will replace the `remote_host` which was + defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`. + :type remote_host: str + :param local_filepath: local file path to get or put. (templated) + :type local_filepath: str + :param remote_filepath: remote file path to get or put. (templated) + :type remote_filepath: str + :param operation: specify operation 'get' or 'put', defaults to put + :type operation: str + :param confirm: specify if the SFTP operation should be confirmed, defaults to True + :type confirm: bool + :param create_intermediate_dirs: create missing intermediate directories when + copying from remote to local and vice-versa. Default is False. + + Example: The following task would copy ``file.txt`` to the remote host + at ``/tmp/tmp1/tmp2/`` while creating ``tmp``,``tmp1`` and ``tmp2`` if they + don't exist. If the parameter is not passed it would error as the directory + does not exist. :: + + put_file = SFTPOperator( + task_id="test_sftp", + ssh_conn_id="ssh_default", + local_filepath="/tmp/file.txt", + remote_filepath="/tmp/tmp1/tmp2/file.txt", + operation="put", + create_intermediate_dirs=True, + dag=dag + ) + + :type create_intermediate_dirs: bool + """ + template_fields = ('local_filepath', 'remote_filepath', 'remote_host') + + @apply_defaults + def __init__(self, + ssh_hook=None, + ssh_conn_id=None, + remote_host=None, + local_filepath=None, + remote_filepath=None, + operation=SFTPOperation.PUT, + confirm=True, + create_intermediate_dirs=False, + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.ssh_hook = ssh_hook + self.ssh_conn_id = ssh_conn_id + self.remote_host = remote_host + self.local_filepath = local_filepath + self.remote_filepath = remote_filepath + self.operation = operation + self.confirm = confirm + self.create_intermediate_dirs = create_intermediate_dirs + if not (self.operation.lower() == SFTPOperation.GET or + self.operation.lower() == SFTPOperation.PUT): + raise TypeError("unsupported operation value {0}, expected {1} or {2}" + .format(self.operation, SFTPOperation.GET, SFTPOperation.PUT)) + + def execute(self, context): + file_msg = None + try: + if self.ssh_conn_id: + if self.ssh_hook and isinstance(self.ssh_hook, SSHHook): + self.log.info("ssh_conn_id is ignored when ssh_hook is provided.") + else: + self.log.info("ssh_hook is not provided or invalid. " + "Trying ssh_conn_id to create SSHHook.") + self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id) + + if not self.ssh_hook: + raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.") + + if self.remote_host is not None: + self.log.info("remote_host is provided explicitly. " + "It will replace the remote_host which was defined " + "in ssh_hook or predefined in connection of ssh_conn_id.") + self.ssh_hook.remote_host = self.remote_host + + with self.ssh_hook.get_conn() as ssh_client: + sftp_client = ssh_client.open_sftp() + if self.operation.lower() == SFTPOperation.GET: + local_folder = os.path.dirname(self.local_filepath) + if self.create_intermediate_dirs: + # Create Intermediate Directories if it doesn't exist + try: + os.makedirs(local_folder) + except OSError: + if not os.path.isdir(local_folder): + raise + file_msg = "from {0} to {1}".format(self.remote_filepath, + self.local_filepath) + self.log.info("Starting to transfer %s", file_msg) + sftp_client.get(self.remote_filepath, self.local_filepath) + else: + remote_folder = os.path.dirname(self.remote_filepath) + if self.create_intermediate_dirs: + _make_intermediate_dirs( + sftp_client=sftp_client, + remote_directory=remote_folder, + ) + file_msg = "from {0} to {1}".format(self.local_filepath, + self.remote_filepath) + self.log.info("Starting to transfer file %s", file_msg) + sftp_client.put(self.local_filepath, + self.remote_filepath, + confirm=self.confirm) + + except Exception as e: + raise AirflowException("Error while transferring {0}, error: {1}" + .format(file_msg, str(e))) + + return self.local_filepath + + +def _make_intermediate_dirs(sftp_client, remote_directory): + """ + Create all the intermediate directories in a remote host + + :param sftp_client: A Paramiko SFTP client. + :param remote_directory: Absolute Path of the directory containing the file + :return: + """ + if remote_directory == '/': + sftp_client.chdir('/') + return + if remote_directory == '': + return + try: + sftp_client.chdir(remote_directory) + except OSError: + dirname, basename = os.path.split(remote_directory.rstrip('/')) + _make_intermediate_dirs(sftp_client, dirname) + sftp_client.mkdir(basename) + sftp_client.chdir(basename) + return diff --git a/airflow/providers/sftp/sensors/__init__.py b/airflow/providers/sftp/sensors/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/airflow/providers/sftp/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/sftp/sensors/sftp_sensor.py b/airflow/providers/sftp/sensors/sftp_sensor.py new file mode 100644 index 0000000000..a3bb0d6ee9 --- /dev/null +++ b/airflow/providers/sftp/sensors/sftp_sensor.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains SFTP sensor. +""" +from paramiko import SFTP_NO_SUCH_FILE + +from airflow.providers.sftp.hooks.sftp_hook import SFTPHook +from airflow.sensors.base_sensor_operator import BaseSensorOperator +from airflow.utils.decorators import apply_defaults + + +class SFTPSensor(BaseSensorOperator): + """ + Waits for a file or directory to be present on SFTP. + + :param path: Remote file or directory path + :type path: str + :param sftp_conn_id: The connection to run the sensor against + :type sftp_conn_id: str + """ + template_fields = ('path',) + + @apply_defaults + def __init__(self, path, sftp_conn_id='sftp_default', *args, **kwargs): + super().__init__(*args, **kwargs) + self.path = path + self.hook = None + self.sftp_conn_id = sftp_conn_id + + def poke(self, context): + self.hook = SFTPHook(self.sftp_conn_id) + self.log.info('Poking for %s', self.path) + try: + self.hook.get_mod_time(self.path) + except OSError as e: + if e.errno != SFTP_NO_SUCH_FILE: + raise e + return False + self.hook.close_conn() + return True diff --git a/docs/autoapi_templates/index.rst b/docs/autoapi_templates/index.rst index 114ee88071..8f39072098 100644 --- a/docs/autoapi_templates/index.rst +++ b/docs/autoapi_templates/index.rst @@ -76,6 +76,8 @@ All operators are in the following packages: airflow/providers/amazon/aws/sensors/index + airflow/providers/apache/cassandra/sensors/index + airflow/providers/google/cloud/operators/index airflow/providers/google/cloud/sensors/index @@ -84,7 +86,9 @@ All operators are in the following packages: airflow/providers/google/marketing_platform/sensors/index - airflow/providers/apache/cassandra/sensors/index + airflow/providers/sftp/operators/index + + airflow/providers/sftp/sensors/index Hooks ----- @@ -117,6 +121,8 @@ All hooks are in the following packages: airflow/providers/apache/cassandra/hooks/index + airflow/providers/sftp/hooks/index + Executors --------- Executors are the mechanism by which task instances get run. All executors are diff --git a/docs/conf.py b/docs/conf.py index 96718a2902..90e3857573 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -233,6 +233,7 @@ exclude_patterns = [ '_api/airflow/providers/amazon/aws/example_dags', '_api/airflow/providers/apache/index.rst', '_api/airflow/providers/apache/cassandra/index.rst', + '_api/airflow/providers/sftp/index.rst', '_api/enums/index.rst', '_api/json_schema/index.rst', '_api/base_serialization/index.rst', diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst index 48e8ed723f..50e82c232d 100644 --- a/docs/operators-and-hooks-ref.rst +++ b/docs/operators-and-hooks-ref.rst @@ -1246,9 +1246,9 @@ communication protocols or interface. * - `SSH File Transfer Protocol (SFTP) `__ - - - :mod:`airflow.contrib.hooks.sftp_hook` - - :mod:`airflow.contrib.operators.sftp_operator` - - :mod:`airflow.contrib.sensors.sftp_sensor` + - :mod:`airflow.providers.sftp.hooks.sftp_hook` + - :mod:`airflow.providers.sftp.operators.sftp_operator` + - :mod:`airflow.providers.sftp.sensors.sftp_sensor` * - `Secure Shell (SSH) `__ - diff --git a/scripts/ci/pylint_todo.txt b/scripts/ci/pylint_todo.txt index 2de583bce7..116f44ab15 100644 --- a/scripts/ci/pylint_todo.txt +++ b/scripts/ci/pylint_todo.txt @@ -21,7 +21,6 @@ ./airflow/contrib/hooks/qubole_check_hook.py ./airflow/contrib/hooks/sagemaker_hook.py ./airflow/contrib/hooks/segment_hook.py -./airflow/contrib/hooks/sftp_hook.py ./airflow/contrib/hooks/slack_webhook_hook.py ./airflow/contrib/hooks/snowflake_hook.py ./airflow/contrib/hooks/spark_jdbc_hook.py @@ -65,7 +64,6 @@ ./airflow/contrib/operators/sagemaker_transform_operator.py ./airflow/contrib/operators/sagemaker_tuning_operator.py ./airflow/contrib/operators/segment_track_event_operator.py -./airflow/contrib/operators/sftp_operator.py ./airflow/contrib/operators/sftp_to_s3_operator.py ./airflow/contrib/operators/slack_webhook_operator.py ./airflow/contrib/operators/snowflake_operator.py @@ -100,7 +98,6 @@ ./airflow/contrib/sensors/sagemaker_training_sensor.py ./airflow/contrib/sensors/sagemaker_transform_sensor.py ./airflow/contrib/sensors/sagemaker_tuning_sensor.py -./airflow/contrib/sensors/sftp_sensor.py ./airflow/contrib/sensors/wasb_sensor.py ./airflow/contrib/sensors/weekday_sensor.py ./airflow/hooks/dbapi_hook.py diff --git a/tests/contrib/hooks/test_sftp_hook.py b/tests/contrib/hooks/test_sftp_hook.py index 0204038b79..ed1f250ca8 100644 --- a/tests/contrib/hooks/test_sftp_hook.py +++ b/tests/contrib/hooks/test_sftp_hook.py @@ -17,231 +17,17 @@ # specific language governing permissions and limitations # under the License. -import os -import shutil -import unittest -from unittest import mock +import importlib +from unittest import TestCase -import pysftp -from parameterized import parameterized - -from airflow.contrib.hooks.sftp_hook import SFTPHook -from airflow.models import Connection -from airflow.utils.db import provide_session - -TMP_PATH = '/tmp' -TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir' -SUB_DIR = "sub_dir" -TMP_FILE_FOR_TESTS = 'test_file.txt' - -SFTP_CONNECTION_USER = "root" +OLD_PATH = "airflow.contrib.hooks.sftp_hook" +NEW_PATH = "airflow.providers.sftp.hooks.sftp_hook" +WARNING_MESSAGE = "This module is deprecated. Please use `{}`.".format(NEW_PATH) -class TestSFTPHook(unittest.TestCase): - - @provide_session - def update_connection(self, login, session=None): - connection = (session.query(Connection). - filter(Connection.conn_id == "sftp_default") - .first()) - old_login = connection.login - connection.login = login - session.commit() - return old_login - - def setUp(self): - self.old_login = self.update_connection(SFTP_CONNECTION_USER) - self.hook = SFTPHook() - os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)) - - with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file: - file.write('Test file') - with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file: - file.write('Test file') - - def test_get_conn(self): - output = self.hook.get_conn() - self.assertEqual(type(output), pysftp.Connection) - - def test_close_conn(self): - self.hook.conn = self.hook.get_conn() - self.assertTrue(self.hook.conn is not None) - self.hook.close_conn() - self.assertTrue(self.hook.conn is None) - - def test_describe_directory(self): - output = self.hook.describe_directory(TMP_PATH) - self.assertTrue(TMP_DIR_FOR_TESTS in output) - - def test_list_directory(self): - output = self.hook.list_directory( - path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) - self.assertEqual(output, [SUB_DIR]) - - def test_create_and_delete_directory(self): - new_dir_name = 'new_dir' - self.hook.create_directory(os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name)) - output = self.hook.describe_directory( - os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) - self.assertTrue(new_dir_name in output) - self.hook.delete_directory(os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name)) - output = self.hook.describe_directory( - os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) - self.assertTrue(new_dir_name not in output) - - def test_create_and_delete_directories(self): - base_dir = "base_dir" - sub_dir = "sub_dir" - new_dir_path = os.path.join(base_dir, sub_dir) - self.hook.create_directory(os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path)) - output = self.hook.describe_directory( - os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) - self.assertTrue(base_dir in output) - output = self.hook.describe_directory( - os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir)) - self.assertTrue(sub_dir in output) - self.hook.delete_directory(os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path)) - self.hook.delete_directory(os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, base_dir)) - output = self.hook.describe_directory( - os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) - self.assertTrue(new_dir_path not in output) - self.assertTrue(base_dir not in output) - - def test_store_retrieve_and_delete_file(self): - self.hook.store_file( - remote_full_path=os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), - local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS) - ) - output = self.hook.list_directory( - path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) - self.assertEqual(output, [SUB_DIR, TMP_FILE_FOR_TESTS]) - retrieved_file_name = 'retrieved.txt' - self.hook.retrieve_file( - remote_full_path=os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), - local_full_path=os.path.join(TMP_PATH, retrieved_file_name) - ) - self.assertTrue(retrieved_file_name in os.listdir(TMP_PATH)) - os.remove(os.path.join(TMP_PATH, retrieved_file_name)) - self.hook.delete_file(path=os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS)) - output = self.hook.list_directory( - path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) - self.assertEqual(output, [SUB_DIR]) - - def test_get_mod_time(self): - self.hook.store_file( - remote_full_path=os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), - local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS) - ) - output = self.hook.get_mod_time(path=os.path.join( - TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS)) - self.assertEqual(len(output), 14) - - @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') - def test_no_host_key_check_default(self, get_connection): - connection = Connection(login='login', host='host') - get_connection.return_value = connection - hook = SFTPHook() - self.assertEqual(hook.no_host_key_check, False) - - @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') - def test_no_host_key_check_enabled(self, get_connection): - connection = Connection( - login='login', host='host', - extra='{"no_host_key_check": true}') - - get_connection.return_value = connection - hook = SFTPHook() - self.assertEqual(hook.no_host_key_check, True) - - @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') - def test_no_host_key_check_disabled(self, get_connection): - connection = Connection( - login='login', host='host', - extra='{"no_host_key_check": false}') - - get_connection.return_value = connection - hook = SFTPHook() - self.assertEqual(hook.no_host_key_check, False) - - @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') - def test_no_host_key_check_disabled_for_all_but_true(self, get_connection): - connection = Connection( - login='login', host='host', - extra='{"no_host_key_check": "foo"}') - - get_connection.return_value = connection - hook = SFTPHook() - self.assertEqual(hook.no_host_key_check, False) - - @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') - def test_no_host_key_check_ignore(self, get_connection): - connection = Connection( - login='login', host='host', - extra='{"ignore_hostkey_verification": true}') - - get_connection.return_value = connection - hook = SFTPHook() - self.assertEqual(hook.no_host_key_check, True) - - @mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection') - def test_no_host_key_check_no_ignore(self, get_connection): - connection = Connection( - login='login', host='host', - extra='{"ignore_hostkey_verification": false}') - - get_connection.return_value = connection - hook = SFTPHook() - self.assertEqual(hook.no_host_key_check, False) - - @parameterized.expand([ - (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True), - (os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True), - (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False), - (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False), - ]) - def test_path_exists(self, path, exists): - result = self.hook.path_exists(path) - self.assertEqual(result, exists) - - @parameterized.expand([ - ("test/path/file.bin", None, None, True), - ("test/path/file.bin", "test", None, True), - ("test/path/file.bin", "test/", None, True), - ("test/path/file.bin", None, "bin", True), - ("test/path/file.bin", "test", "bin", True), - ("test/path/file.bin", "test/", "file.bin", True), - ("test/path/file.bin", None, "file.bin", True), - ("test/path/file.bin", "diff", None, False), - ("test/path/file.bin", "test//", None, False), - ("test/path/file.bin", None, ".txt", False), - ("test/path/file.bin", "diff", ".txt", False), - ]) - def test_path_match(self, path, prefix, delimiter, match): - result = self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter) - self.assertEqual(result, match) - - def test_get_tree_map(self): - tree_map = self.hook.get_tree_map(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) - files, dirs, unknowns = tree_map - - self.assertEqual(files, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS)]) - self.assertEqual(dirs, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)]) - self.assertEqual(unknowns, []) - - def tearDown(self): - shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) - os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)) - self.update_connection(self.old_login) - - -if __name__ == '__main__': - unittest.main() +class TestMovingSFTPHookToCore(TestCase): + def test_move_sftp_hook_to_core(self): + with self.assertWarns(DeprecationWarning) as warn: + # Reload to see deprecation warning each time + importlib.import_module(OLD_PATH) + self.assertEqual(WARNING_MESSAGE, str(warn.warning)) diff --git a/tests/contrib/operators/test_sftp_operator.py b/tests/contrib/operators/test_sftp_operator.py index 4b554eab51..cd2c889760 100644 --- a/tests/contrib/operators/test_sftp_operator.py +++ b/tests/contrib/operators/test_sftp_operator.py @@ -17,436 +17,17 @@ # specific language governing permissions and limitations # under the License. -import os -import unittest -from base64 import b64encode -from unittest import mock +import importlib +from unittest import TestCase -from airflow import AirflowException -from airflow.contrib.operators.sftp_operator import SFTPOperation, SFTPOperator -from airflow.contrib.operators.ssh_operator import SSHOperator -from airflow.models import DAG, TaskInstance -from airflow.utils import timezone -from airflow.utils.timezone import datetime -from tests.test_utils.config import conf_vars - -TEST_DAG_ID = 'unit_tests_sftp_op' -DEFAULT_DATE = datetime(2017, 1, 1) -TEST_CONN_ID = "conn_id_for_testing" +OLD_PATH = "airflow.contrib.operators.sftp_operator" +NEW_PATH = "airflow.providers.sftp.operators.sftp_operator" +WARNING_MESSAGE = "This module is deprecated. Please use `{}`.".format(NEW_PATH) -class TestSFTPOperator(unittest.TestCase): - def setUp(self): - from airflow.contrib.hooks.ssh_hook import SSHHook - - hook = SSHHook(ssh_conn_id='ssh_default') - hook.no_host_key_check = True - args = { - 'owner': 'airflow', - 'start_date': DEFAULT_DATE, - } - dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) - dag.schedule_interval = '@once' - self.hook = hook - self.dag = dag - self.test_dir = "/tmp" - self.test_local_dir = "/tmp/tmp2" - self.test_remote_dir = "/tmp/tmp1" - self.test_local_filename = 'test_local_file' - self.test_remote_filename = 'test_remote_file' - self.test_local_filepath = '{0}/{1}'.format(self.test_dir, - self.test_local_filename) - # Local Filepath with Intermediate Directory - self.test_local_filepath_int_dir = '{0}/{1}'.format(self.test_local_dir, - self.test_local_filename) - self.test_remote_filepath = '{0}/{1}'.format(self.test_dir, - self.test_remote_filename) - # Remote Filepath with Intermediate Directory - self.test_remote_filepath_int_dir = '{0}/{1}'.format(self.test_remote_dir, - self.test_remote_filename) - - @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_pickle_file_transfer_put(self): - test_local_file_content = \ - b"This is local file content \n which is multiline " \ - b"continuing....with other character\nanother line here \n this is last line" - # create a test file locally - with open(self.test_local_filepath, 'wb') as file: - file.write(test_local_file_content) - - # put test file to remote - put_test_task = SFTPOperator( - task_id="test_sftp", - ssh_hook=self.hook, - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath, - operation=SFTPOperation.PUT, - create_intermediate_dirs=True, - dag=self.dag - ) - self.assertIsNotNone(put_test_task) - ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) - ti2.run() - - # check the remote file content - check_file_task = SSHOperator( - task_id="test_check_file", - ssh_hook=self.hook, - command="cat {0}".format(self.test_remote_filepath), - do_xcom_push=True, - dag=self.dag - ) - self.assertIsNotNone(check_file_task) - ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) - ti3.run() - self.assertEqual( - ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(), - test_local_file_content) - - @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_file_transfer_no_intermediate_dir_error_put(self): - test_local_file_content = \ - b"This is local file content \n which is multiline " \ - b"continuing....with other character\nanother line here \n this is last line" - # create a test file locally - with open(self.test_local_filepath, 'wb') as file: - file.write(test_local_file_content) - - # Try to put test file to remote - # This should raise an error with "No such file" as the directory - # does not exist - with self.assertRaises(Exception) as error: - put_test_task = SFTPOperator( - task_id="test_sftp", - ssh_hook=self.hook, - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath_int_dir, - operation=SFTPOperation.PUT, - create_intermediate_dirs=False, - dag=self.dag - ) - self.assertIsNotNone(put_test_task) - ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) - ti2.run() - self.assertIn('No such file', str(error.exception)) - - @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_file_transfer_with_intermediate_dir_put(self): - test_local_file_content = \ - b"This is local file content \n which is multiline " \ - b"continuing....with other character\nanother line here \n this is last line" - # create a test file locally - with open(self.test_local_filepath, 'wb') as file: - file.write(test_local_file_content) - - # put test file to remote - put_test_task = SFTPOperator( - task_id="test_sftp", - ssh_hook=self.hook, - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath_int_dir, - operation=SFTPOperation.PUT, - create_intermediate_dirs=True, - dag=self.dag - ) - self.assertIsNotNone(put_test_task) - ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) - ti2.run() - - # check the remote file content - check_file_task = SSHOperator( - task_id="test_check_file", - ssh_hook=self.hook, - command="cat {0}".format(self.test_remote_filepath_int_dir), - do_xcom_push=True, - dag=self.dag - ) - self.assertIsNotNone(check_file_task) - ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) - ti3.run() - self.assertEqual( - ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(), - test_local_file_content) - - @conf_vars({('core', 'enable_xcom_pickling'): 'False'}) - def test_json_file_transfer_put(self): - test_local_file_content = \ - b"This is local file content \n which is multiline " \ - b"continuing....with other character\nanother line here \n this is last line" - # create a test file locally - with open(self.test_local_filepath, 'wb') as file: - file.write(test_local_file_content) - - # put test file to remote - put_test_task = SFTPOperator( - task_id="test_sftp", - ssh_hook=self.hook, - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath, - operation=SFTPOperation.PUT, - dag=self.dag - ) - self.assertIsNotNone(put_test_task) - ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) - ti2.run() - - # check the remote file content - check_file_task = SSHOperator( - task_id="test_check_file", - ssh_hook=self.hook, - command="cat {0}".format(self.test_remote_filepath), - do_xcom_push=True, - dag=self.dag - ) - self.assertIsNotNone(check_file_task) - ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) - ti3.run() - self.assertEqual( - ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(), - b64encode(test_local_file_content).decode('utf-8')) - - @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_pickle_file_transfer_get(self): - test_remote_file_content = \ - "This is remote file content \n which is also multiline " \ - "another line here \n this is last line. EOF" - - # create a test file remotely - create_file_task = SSHOperator( - task_id="test_create_file", - ssh_hook=self.hook, - command="echo '{0}' > {1}".format(test_remote_file_content, - self.test_remote_filepath), - do_xcom_push=True, - dag=self.dag - ) - self.assertIsNotNone(create_file_task) - ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) - ti1.run() - - # get remote file to local - get_test_task = SFTPOperator( - task_id="test_sftp", - ssh_hook=self.hook, - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath, - operation=SFTPOperation.GET, - dag=self.dag - ) - self.assertIsNotNone(get_test_task) - ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) - ti2.run() - - # test the received content - content_received = None - with open(self.test_local_filepath, 'r') as file: - content_received = file.read() - self.assertEqual(content_received.strip(), test_remote_file_content) - - @conf_vars({('core', 'enable_xcom_pickling'): 'False'}) - def test_json_file_transfer_get(self): - test_remote_file_content = \ - "This is remote file content \n which is also multiline " \ - "another line here \n this is last line. EOF" - - # create a test file remotely - create_file_task = SSHOperator( - task_id="test_create_file", - ssh_hook=self.hook, - command="echo '{0}' > {1}".format(test_remote_file_content, - self.test_remote_filepath), - do_xcom_push=True, - dag=self.dag - ) - self.assertIsNotNone(create_file_task) - ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) - ti1.run() - - # get remote file to local - get_test_task = SFTPOperator( - task_id="test_sftp", - ssh_hook=self.hook, - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath, - operation=SFTPOperation.GET, - dag=self.dag - ) - self.assertIsNotNone(get_test_task) - ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) - ti2.run() - - # test the received content - content_received = None - with open(self.test_local_filepath, 'r') as file: - content_received = file.read() - self.assertEqual(content_received.strip(), - test_remote_file_content.encode('utf-8').decode('utf-8')) - - @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_file_transfer_no_intermediate_dir_error_get(self): - test_remote_file_content = \ - "This is remote file content \n which is also multiline " \ - "another line here \n this is last line. EOF" - - # create a test file remotely - create_file_task = SSHOperator( - task_id="test_create_file", - ssh_hook=self.hook, - command="echo '{0}' > {1}".format(test_remote_file_content, - self.test_remote_filepath), - do_xcom_push=True, - dag=self.dag - ) - self.assertIsNotNone(create_file_task) - ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) - ti1.run() - - # Try to GET test file from remote - # This should raise an error with "No such file" as the directory - # does not exist - with self.assertRaises(Exception) as error: - get_test_task = SFTPOperator( - task_id="test_sftp", - ssh_hook=self.hook, - local_filepath=self.test_local_filepath_int_dir, - remote_filepath=self.test_remote_filepath, - operation=SFTPOperation.GET, - dag=self.dag - ) - self.assertIsNotNone(get_test_task) - ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) - ti2.run() - self.assertIn('No such file', str(error.exception)) - - @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) - def test_file_transfer_with_intermediate_dir_error_get(self): - test_remote_file_content = \ - "This is remote file content \n which is also multiline " \ - "another line here \n this is last line. EOF" - - # create a test file remotely - create_file_task = SSHOperator( - task_id="test_create_file", - ssh_hook=self.hook, - command="echo '{0}' > {1}".format(test_remote_file_content, - self.test_remote_filepath), - do_xcom_push=True, - dag=self.dag - ) - self.assertIsNotNone(create_file_task) - ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) - ti1.run() - - # get remote file to local - get_test_task = SFTPOperator( - task_id="test_sftp", - ssh_hook=self.hook, - local_filepath=self.test_local_filepath_int_dir, - remote_filepath=self.test_remote_filepath, - operation=SFTPOperation.GET, - create_intermediate_dirs=True, - dag=self.dag - ) - self.assertIsNotNone(get_test_task) - ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) - ti2.run() - - # test the received content - content_received = None - with open(self.test_local_filepath_int_dir, 'r') as file: - content_received = file.read() - self.assertEqual(content_received.strip(), test_remote_file_content) - - @mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"}) - def test_arg_checking(self): - # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided - with self.assertRaisesRegex(AirflowException, - "Cannot operate without ssh_hook or ssh_conn_id."): - task_0 = SFTPOperator( - task_id="test_sftp", - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath, - operation=SFTPOperation.PUT, - dag=self.dag - ) - task_0.execute(None) - - # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook - task_1 = SFTPOperator( - task_id="test_sftp", - ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook - ssh_conn_id=TEST_CONN_ID, - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath, - operation=SFTPOperation.PUT, - dag=self.dag - ) - try: - task_1.execute(None) - except Exception: - pass - self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID) - - task_2 = SFTPOperator( - task_id="test_sftp", - ssh_conn_id=TEST_CONN_ID, # no ssh_hook provided - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath, - operation=SFTPOperation.PUT, - dag=self.dag - ) - try: - task_2.execute(None) - except Exception: - pass - self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID) - - # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id - task_3 = SFTPOperator( - task_id="test_sftp", - ssh_hook=self.hook, - ssh_conn_id=TEST_CONN_ID, - local_filepath=self.test_local_filepath, - remote_filepath=self.test_remote_filepath, - operation=SFTPOperation.PUT, - dag=self.dag - ) - try: - task_3.execute(None) - except Exception: - pass - self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id) - - def delete_local_resource(self): - if os.path.exists(self.test_local_filepath): - os.remove(self.test_local_filepath) - if os.path.exists(self.test_local_filepath_int_dir): - os.remove(self.test_local_filepath_int_dir) - if os.path.exists(self.test_local_dir): - os.rmdir(self.test_local_dir) - - def delete_remote_resource(self): - if os.path.exists(self.test_remote_filepath): - # check the remote file content - remove_file_task = SSHOperator( - task_id="test_check_file", - ssh_hook=self.hook, - command="rm {0}".format(self.test_remote_filepath), - do_xcom_push=True, - dag=self.dag - ) - self.assertIsNotNone(remove_file_task) - ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow()) - ti3.run() - if os.path.exists(self.test_remote_filepath_int_dir): - os.remove(self.test_remote_filepath_int_dir) - if os.path.exists(self.test_remote_dir): - os.rmdir(self.test_remote_dir) - - def tearDown(self): - self.delete_local_resource() - self.delete_remote_resource() - - -if __name__ == '__main__': - unittest.main() +class TestMovingSFTPHookToCore(TestCase): + def test_move_sftp_operator_to_core(self): + with self.assertWarns(DeprecationWarning) as warn: + # Reload to see deprecation warning each time + importlib.import_module(OLD_PATH) + self.assertEqual(WARNING_MESSAGE, str(warn.warning)) diff --git a/tests/contrib/sensors/test_sftp_sensor.py b/tests/contrib/sensors/test_sftp_sensor.py index 351dac8408..f3f8c2c842 100644 --- a/tests/contrib/sensors/test_sftp_sensor.py +++ b/tests/contrib/sensors/test_sftp_sensor.py @@ -17,61 +17,17 @@ # specific language governing permissions and limitations # under the License. -import unittest -from unittest.mock import patch +import importlib +from unittest import TestCase -from paramiko import SFTP_FAILURE, SFTP_NO_SUCH_FILE - -from airflow.contrib.sensors.sftp_sensor import SFTPSensor +OLD_PATH = "airflow.contrib.sensors.sftp_sensor" +NEW_PATH = "airflow.providers.sftp.sensors.sftp_sensor" +WARNING_MESSAGE = "This module is deprecated. Please use `{}`.".format(NEW_PATH) -class TestSFTPSensor(unittest.TestCase): - @patch('airflow.contrib.sensors.sftp_sensor.SFTPHook') - def test_file_present(self, sftp_hook_mock): - sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000' - sftp_sensor = SFTPSensor( - task_id='unit_test', - path='/path/to/file/1970-01-01.txt') - context = { - 'ds': '1970-01-01' - } - output = sftp_sensor.poke(context) - sftp_hook_mock.return_value.get_mod_time.assert_called_once_with( - '/path/to/file/1970-01-01.txt') - self.assertTrue(output) - - @patch('airflow.contrib.sensors.sftp_sensor.SFTPHook') - def test_file_absent(self, sftp_hook_mock): - sftp_hook_mock.return_value.get_mod_time.side_effect = OSError( - SFTP_NO_SUCH_FILE, 'File missing') - sftp_sensor = SFTPSensor( - task_id='unit_test', - path='/path/to/file/1970-01-01.txt') - context = { - 'ds': '1970-01-01' - } - output = sftp_sensor.poke(context) - sftp_hook_mock.return_value.get_mod_time.assert_called_once_with( - '/path/to/file/1970-01-01.txt') - self.assertFalse(output) - - @patch('airflow.contrib.sensors.sftp_sensor.SFTPHook') - def test_sftp_failure(self, sftp_hook_mock): - sftp_hook_mock.return_value.get_mod_time.side_effect = OSError( - SFTP_FAILURE, 'SFTP failure') - sftp_sensor = SFTPSensor( - task_id='unit_test', - path='/path/to/file/1970-01-01.txt') - context = { - 'ds': '1970-01-01' - } - with self.assertRaises(OSError): - sftp_sensor.poke(context) - sftp_hook_mock.return_value.get_mod_time.assert_called_once_with( - '/path/to/file/1970-01-01.txt') - - def test_hook_not_created_during_init(self): - sftp_sensor = SFTPSensor( - task_id='unit_test', - path='/path/to/file/1970-01-01.txt') - self.assertIsNone(sftp_sensor.hook) +class TestMovingSFTPHookToCore(TestCase): + def test_move_sftp_sensor_to_core(self): + with self.assertWarns(DeprecationWarning) as warn: + # Reload to see deprecation warning each time + importlib.import_module(OLD_PATH) + self.assertEqual(WARNING_MESSAGE, str(warn.warning)) diff --git a/tests/providers/sftp/__init__.py b/tests/providers/sftp/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/providers/sftp/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/sftp/hooks/__init__.py b/tests/providers/sftp/hooks/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/providers/sftp/hooks/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/sftp/hooks/test_sftp_hook.py b/tests/providers/sftp/hooks/test_sftp_hook.py new file mode 100644 index 0000000000..ce71994774 --- /dev/null +++ b/tests/providers/sftp/hooks/test_sftp_hook.py @@ -0,0 +1,247 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import shutil +import unittest +from unittest import mock + +import pysftp +from parameterized import parameterized + +from airflow.models import Connection +from airflow.providers.sftp.hooks.sftp_hook import SFTPHook +from airflow.utils.db import provide_session + +TMP_PATH = '/tmp' +TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir' +SUB_DIR = "sub_dir" +TMP_FILE_FOR_TESTS = 'test_file.txt' + +SFTP_CONNECTION_USER = "root" + + +class TestSFTPHook(unittest.TestCase): + + @provide_session + def update_connection(self, login, session=None): + connection = (session.query(Connection). + filter(Connection.conn_id == "sftp_default") + .first()) + old_login = connection.login + connection.login = login + session.commit() + return old_login + + def setUp(self): + self.old_login = self.update_connection(SFTP_CONNECTION_USER) + self.hook = SFTPHook() + os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)) + + with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file: + file.write('Test file') + with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file: + file.write('Test file') + + def test_get_conn(self): + output = self.hook.get_conn() + self.assertEqual(type(output), pysftp.Connection) + + def test_close_conn(self): + self.hook.conn = self.hook.get_conn() + self.assertTrue(self.hook.conn is not None) + self.hook.close_conn() + self.assertTrue(self.hook.conn is None) + + def test_describe_directory(self): + output = self.hook.describe_directory(TMP_PATH) + self.assertTrue(TMP_DIR_FOR_TESTS in output) + + def test_list_directory(self): + output = self.hook.list_directory( + path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + self.assertEqual(output, [SUB_DIR]) + + def test_create_and_delete_directory(self): + new_dir_name = 'new_dir' + self.hook.create_directory(os.path.join( + TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name)) + output = self.hook.describe_directory( + os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + self.assertTrue(new_dir_name in output) + self.hook.delete_directory(os.path.join( + TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name)) + output = self.hook.describe_directory( + os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + self.assertTrue(new_dir_name not in output) + + def test_create_and_delete_directories(self): + base_dir = "base_dir" + sub_dir = "sub_dir" + new_dir_path = os.path.join(base_dir, sub_dir) + self.hook.create_directory(os.path.join( + TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path)) + output = self.hook.describe_directory( + os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + self.assertTrue(base_dir in output) + output = self.hook.describe_directory( + os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir)) + self.assertTrue(sub_dir in output) + self.hook.delete_directory(os.path.join( + TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path)) + self.hook.delete_directory(os.path.join( + TMP_PATH, TMP_DIR_FOR_TESTS, base_dir)) + output = self.hook.describe_directory( + os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + self.assertTrue(new_dir_path not in output) + self.assertTrue(base_dir not in output) + + def test_store_retrieve_and_delete_file(self): + self.hook.store_file( + remote_full_path=os.path.join( + TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), + local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS) + ) + output = self.hook.list_directory( + path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + self.assertEqual(output, [SUB_DIR, TMP_FILE_FOR_TESTS]) + retrieved_file_name = 'retrieved.txt' + self.hook.retrieve_file( + remote_full_path=os.path.join( + TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), + local_full_path=os.path.join(TMP_PATH, retrieved_file_name) + ) + self.assertTrue(retrieved_file_name in os.listdir(TMP_PATH)) + os.remove(os.path.join(TMP_PATH, retrieved_file_name)) + self.hook.delete_file(path=os.path.join( + TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS)) + output = self.hook.list_directory( + path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + self.assertEqual(output, [SUB_DIR]) + + def test_get_mod_time(self): + self.hook.store_file( + remote_full_path=os.path.join( + TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS), + local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS) + ) + output = self.hook.get_mod_time(path=os.path.join( + TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS)) + self.assertEqual(len(output), 14) + + @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection') + def test_no_host_key_check_default(self, get_connection): + connection = Connection(login='login', host='host') + get_connection.return_value = connection + hook = SFTPHook() + self.assertEqual(hook.no_host_key_check, False) + + @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection') + def test_no_host_key_check_enabled(self, get_connection): + connection = Connection( + login='login', host='host', + extra='{"no_host_key_check": true}') + + get_connection.return_value = connection + hook = SFTPHook() + self.assertEqual(hook.no_host_key_check, True) + + @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection') + def test_no_host_key_check_disabled(self, get_connection): + connection = Connection( + login='login', host='host', + extra='{"no_host_key_check": false}') + + get_connection.return_value = connection + hook = SFTPHook() + self.assertEqual(hook.no_host_key_check, False) + + @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection') + def test_no_host_key_check_disabled_for_all_but_true(self, get_connection): + connection = Connection( + login='login', host='host', + extra='{"no_host_key_check": "foo"}') + + get_connection.return_value = connection + hook = SFTPHook() + self.assertEqual(hook.no_host_key_check, False) + + @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection') + def test_no_host_key_check_ignore(self, get_connection): + connection = Connection( + login='login', host='host', + extra='{"ignore_hostkey_verification": true}') + + get_connection.return_value = connection + hook = SFTPHook() + self.assertEqual(hook.no_host_key_check, True) + + @mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection') + def test_no_host_key_check_no_ignore(self, get_connection): + connection = Connection( + login='login', host='host', + extra='{"ignore_hostkey_verification": false}') + + get_connection.return_value = connection + hook = SFTPHook() + self.assertEqual(hook.no_host_key_check, False) + + @parameterized.expand([ + (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True), + (os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True), + (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False), + (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False), + ]) + def test_path_exists(self, path, exists): + result = self.hook.path_exists(path) + self.assertEqual(result, exists) + + @parameterized.expand([ + ("test/path/file.bin", None, None, True), + ("test/path/file.bin", "test", None, True), + ("test/path/file.bin", "test/", None, True), + ("test/path/file.bin", None, "bin", True), + ("test/path/file.bin", "test", "bin", True), + ("test/path/file.bin", "test/", "file.bin", True), + ("test/path/file.bin", None, "file.bin", True), + ("test/path/file.bin", "diff", None, False), + ("test/path/file.bin", "test//", None, False), + ("test/path/file.bin", None, ".txt", False), + ("test/path/file.bin", "diff", ".txt", False), + ]) + def test_path_match(self, path, prefix, delimiter, match): + result = self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter) + self.assertEqual(result, match) + + def test_get_tree_map(self): + tree_map = self.hook.get_tree_map(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + files, dirs, unknowns = tree_map + + self.assertEqual(files, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS)]) + self.assertEqual(dirs, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)]) + self.assertEqual(unknowns, []) + + def tearDown(self): + shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) + os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)) + self.update_connection(self.old_login) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/providers/sftp/operators/__init__.py b/tests/providers/sftp/operators/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/providers/sftp/operators/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/sftp/operators/test_sftp_operator.py b/tests/providers/sftp/operators/test_sftp_operator.py new file mode 100644 index 0000000000..e05d138a27 --- /dev/null +++ b/tests/providers/sftp/operators/test_sftp_operator.py @@ -0,0 +1,452 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import unittest +from base64 import b64encode +from unittest import mock + +from airflow import AirflowException +from airflow.contrib.operators.ssh_operator import SSHOperator +from airflow.models import DAG, TaskInstance +from airflow.providers.sftp.operators.sftp_operator import SFTPOperation, SFTPOperator +from airflow.utils import timezone +from airflow.utils.timezone import datetime +from tests.test_utils.config import conf_vars + +TEST_DAG_ID = 'unit_tests_sftp_op' +DEFAULT_DATE = datetime(2017, 1, 1) +TEST_CONN_ID = "conn_id_for_testing" + + +class TestSFTPOperator(unittest.TestCase): + def setUp(self): + from airflow.contrib.hooks.ssh_hook import SSHHook + + hook = SSHHook(ssh_conn_id='ssh_default') + hook.no_host_key_check = True + args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE, + } + dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args) + dag.schedule_interval = '@once' + self.hook = hook + self.dag = dag + self.test_dir = "/tmp" + self.test_local_dir = "/tmp/tmp2" + self.test_remote_dir = "/tmp/tmp1" + self.test_local_filename = 'test_local_file' + self.test_remote_filename = 'test_remote_file' + self.test_local_filepath = '{0}/{1}'.format(self.test_dir, + self.test_local_filename) + # Local Filepath with Intermediate Directory + self.test_local_filepath_int_dir = '{0}/{1}'.format(self.test_local_dir, + self.test_local_filename) + self.test_remote_filepath = '{0}/{1}'.format(self.test_dir, + self.test_remote_filename) + # Remote Filepath with Intermediate Directory + self.test_remote_filepath_int_dir = '{0}/{1}'.format(self.test_remote_dir, + self.test_remote_filename) + + @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) + def test_pickle_file_transfer_put(self): + test_local_file_content = \ + b"This is local file content \n which is multiline " \ + b"continuing....with other character\nanother line here \n this is last line" + # create a test file locally + with open(self.test_local_filepath, 'wb') as file: + file.write(test_local_file_content) + + # put test file to remote + put_test_task = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + create_intermediate_dirs=True, + dag=self.dag + ) + self.assertIsNotNone(put_test_task) + ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) + ti2.run() + + # check the remote file content + check_file_task = SSHOperator( + task_id="test_check_file", + ssh_hook=self.hook, + command="cat {0}".format(self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag + ) + self.assertIsNotNone(check_file_task) + ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) + ti3.run() + self.assertEqual( + ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(), + test_local_file_content) + + @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) + def test_file_transfer_no_intermediate_dir_error_put(self): + test_local_file_content = \ + b"This is local file content \n which is multiline " \ + b"continuing....with other character\nanother line here \n this is last line" + # create a test file locally + with open(self.test_local_filepath, 'wb') as file: + file.write(test_local_file_content) + + # Try to put test file to remote + # This should raise an error with "No such file" as the directory + # does not exist + with self.assertRaises(Exception) as error: + put_test_task = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath_int_dir, + operation=SFTPOperation.PUT, + create_intermediate_dirs=False, + dag=self.dag + ) + self.assertIsNotNone(put_test_task) + ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) + ti2.run() + self.assertIn('No such file', str(error.exception)) + + @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) + def test_file_transfer_with_intermediate_dir_put(self): + test_local_file_content = \ + b"This is local file content \n which is multiline " \ + b"continuing....with other character\nanother line here \n this is last line" + # create a test file locally + with open(self.test_local_filepath, 'wb') as file: + file.write(test_local_file_content) + + # put test file to remote + put_test_task = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath_int_dir, + operation=SFTPOperation.PUT, + create_intermediate_dirs=True, + dag=self.dag + ) + self.assertIsNotNone(put_test_task) + ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) + ti2.run() + + # check the remote file content + check_file_task = SSHOperator( + task_id="test_check_file", + ssh_hook=self.hook, + command="cat {0}".format(self.test_remote_filepath_int_dir), + do_xcom_push=True, + dag=self.dag + ) + self.assertIsNotNone(check_file_task) + ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) + ti3.run() + self.assertEqual( + ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(), + test_local_file_content) + + @conf_vars({('core', 'enable_xcom_pickling'): 'False'}) + def test_json_file_transfer_put(self): + test_local_file_content = \ + b"This is local file content \n which is multiline " \ + b"continuing....with other character\nanother line here \n this is last line" + # create a test file locally + with open(self.test_local_filepath, 'wb') as file: + file.write(test_local_file_content) + + # put test file to remote + put_test_task = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + self.assertIsNotNone(put_test_task) + ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow()) + ti2.run() + + # check the remote file content + check_file_task = SSHOperator( + task_id="test_check_file", + ssh_hook=self.hook, + command="cat {0}".format(self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag + ) + self.assertIsNotNone(check_file_task) + ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow()) + ti3.run() + self.assertEqual( + ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(), + b64encode(test_local_file_content).decode('utf-8')) + + @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) + def test_pickle_file_transfer_get(self): + test_remote_file_content = \ + "This is remote file content \n which is also multiline " \ + "another line here \n this is last line. EOF" + + # create a test file remotely + create_file_task = SSHOperator( + task_id="test_create_file", + ssh_hook=self.hook, + command="echo '{0}' > {1}".format(test_remote_file_content, + self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag + ) + self.assertIsNotNone(create_file_task) + ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) + ti1.run() + + # get remote file to local + get_test_task = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.GET, + dag=self.dag + ) + self.assertIsNotNone(get_test_task) + ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) + ti2.run() + + # test the received content + content_received = None + with open(self.test_local_filepath, 'r') as file: + content_received = file.read() + self.assertEqual(content_received.strip(), test_remote_file_content) + + @conf_vars({('core', 'enable_xcom_pickling'): 'False'}) + def test_json_file_transfer_get(self): + test_remote_file_content = \ + "This is remote file content \n which is also multiline " \ + "another line here \n this is last line. EOF" + + # create a test file remotely + create_file_task = SSHOperator( + task_id="test_create_file", + ssh_hook=self.hook, + command="echo '{0}' > {1}".format(test_remote_file_content, + self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag + ) + self.assertIsNotNone(create_file_task) + ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) + ti1.run() + + # get remote file to local + get_test_task = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.GET, + dag=self.dag + ) + self.assertIsNotNone(get_test_task) + ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) + ti2.run() + + # test the received content + content_received = None + with open(self.test_local_filepath, 'r') as file: + content_received = file.read() + self.assertEqual(content_received.strip(), + test_remote_file_content.encode('utf-8').decode('utf-8')) + + @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) + def test_file_transfer_no_intermediate_dir_error_get(self): + test_remote_file_content = \ + "This is remote file content \n which is also multiline " \ + "another line here \n this is last line. EOF" + + # create a test file remotely + create_file_task = SSHOperator( + task_id="test_create_file", + ssh_hook=self.hook, + command="echo '{0}' > {1}".format(test_remote_file_content, + self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag + ) + self.assertIsNotNone(create_file_task) + ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) + ti1.run() + + # Try to GET test file from remote + # This should raise an error with "No such file" as the directory + # does not exist + with self.assertRaises(Exception) as error: + get_test_task = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath_int_dir, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.GET, + dag=self.dag + ) + self.assertIsNotNone(get_test_task) + ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) + ti2.run() + self.assertIn('No such file', str(error.exception)) + + @conf_vars({('core', 'enable_xcom_pickling'): 'True'}) + def test_file_transfer_with_intermediate_dir_error_get(self): + test_remote_file_content = \ + "This is remote file content \n which is also multiline " \ + "another line here \n this is last line. EOF" + + # create a test file remotely + create_file_task = SSHOperator( + task_id="test_create_file", + ssh_hook=self.hook, + command="echo '{0}' > {1}".format(test_remote_file_content, + self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag + ) + self.assertIsNotNone(create_file_task) + ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow()) + ti1.run() + + # get remote file to local + get_test_task = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + local_filepath=self.test_local_filepath_int_dir, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.GET, + create_intermediate_dirs=True, + dag=self.dag + ) + self.assertIsNotNone(get_test_task) + ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow()) + ti2.run() + + # test the received content + content_received = None + with open(self.test_local_filepath_int_dir, 'r') as file: + content_received = file.read() + self.assertEqual(content_received.strip(), test_remote_file_content) + + @mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"}) + def test_arg_checking(self): + # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided + with self.assertRaisesRegex(AirflowException, + "Cannot operate without ssh_hook or ssh_conn_id."): + task_0 = SFTPOperator( + task_id="test_sftp", + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + task_0.execute(None) + + # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook + task_1 = SFTPOperator( + task_id="test_sftp", + ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook + ssh_conn_id=TEST_CONN_ID, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + try: + task_1.execute(None) + except Exception: # pylint: disable=broad-except + pass + self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID) + + task_2 = SFTPOperator( + task_id="test_sftp", + ssh_conn_id=TEST_CONN_ID, # no ssh_hook provided + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + try: + task_2.execute(None) + except Exception: # pylint: disable=broad-except + pass + self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID) + + # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id + task_3 = SFTPOperator( + task_id="test_sftp", + ssh_hook=self.hook, + ssh_conn_id=TEST_CONN_ID, + local_filepath=self.test_local_filepath, + remote_filepath=self.test_remote_filepath, + operation=SFTPOperation.PUT, + dag=self.dag + ) + try: + task_3.execute(None) + except Exception: # pylint: disable=broad-except + pass + self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id) + + def delete_local_resource(self): + if os.path.exists(self.test_local_filepath): + os.remove(self.test_local_filepath) + if os.path.exists(self.test_local_filepath_int_dir): + os.remove(self.test_local_filepath_int_dir) + if os.path.exists(self.test_local_dir): + os.rmdir(self.test_local_dir) + + def delete_remote_resource(self): + if os.path.exists(self.test_remote_filepath): + # check the remote file content + remove_file_task = SSHOperator( + task_id="test_check_file", + ssh_hook=self.hook, + command="rm {0}".format(self.test_remote_filepath), + do_xcom_push=True, + dag=self.dag + ) + self.assertIsNotNone(remove_file_task) + ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow()) + ti3.run() + if os.path.exists(self.test_remote_filepath_int_dir): + os.remove(self.test_remote_filepath_int_dir) + if os.path.exists(self.test_remote_dir): + os.rmdir(self.test_remote_dir) + + def tearDown(self): + self.delete_local_resource() + self.delete_remote_resource() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/providers/sftp/sensors/__init__.py b/tests/providers/sftp/sensors/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/tests/providers/sftp/sensors/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/sftp/sensors/test_sftp_sensor.py b/tests/providers/sftp/sensors/test_sftp_sensor.py new file mode 100644 index 0000000000..163f72073a --- /dev/null +++ b/tests/providers/sftp/sensors/test_sftp_sensor.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest +from unittest.mock import patch + +from paramiko import SFTP_FAILURE, SFTP_NO_SUCH_FILE + +from airflow.providers.sftp.sensors.sftp_sensor import SFTPSensor + + +class TestSFTPSensor(unittest.TestCase): + @patch('airflow.providers.sftp.sensors.sftp_sensor.SFTPHook') + def test_file_present(self, sftp_hook_mock): + sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000' + sftp_sensor = SFTPSensor( + task_id='unit_test', + path='/path/to/file/1970-01-01.txt') + context = { + 'ds': '1970-01-01' + } + output = sftp_sensor.poke(context) + sftp_hook_mock.return_value.get_mod_time.assert_called_once_with( + '/path/to/file/1970-01-01.txt') + self.assertTrue(output) + + @patch('airflow.providers.sftp.sensors.sftp_sensor.SFTPHook') + def test_file_absent(self, sftp_hook_mock): + sftp_hook_mock.return_value.get_mod_time.side_effect = OSError( + SFTP_NO_SUCH_FILE, 'File missing') + sftp_sensor = SFTPSensor( + task_id='unit_test', + path='/path/to/file/1970-01-01.txt') + context = { + 'ds': '1970-01-01' + } + output = sftp_sensor.poke(context) + sftp_hook_mock.return_value.get_mod_time.assert_called_once_with( + '/path/to/file/1970-01-01.txt') + self.assertFalse(output) + + @patch('airflow.providers.sftp.sensors.sftp_sensor.SFTPHook') + def test_sftp_failure(self, sftp_hook_mock): + sftp_hook_mock.return_value.get_mod_time.side_effect = OSError( + SFTP_FAILURE, 'SFTP failure') + sftp_sensor = SFTPSensor( + task_id='unit_test', + path='/path/to/file/1970-01-01.txt') + context = { + 'ds': '1970-01-01' + } + with self.assertRaises(OSError): + sftp_sensor.poke(context) + sftp_hook_mock.return_value.get_mod_time.assert_called_once_with( + '/path/to/file/1970-01-01.txt') + + def test_hook_not_created_during_init(self): + sftp_sensor = SFTPSensor( + task_id='unit_test', + path='/path/to/file/1970-01-01.txt') + self.assertIsNone(sftp_sensor.hook)