[AIRFLOW-5807] Move SFTP from contrib to providers. (#6464)
* [AIRFLOW-5807] Move SFTP from contrib to core
This commit is contained in:
Родитель
25830a01a9
Коммит
69629a5a94
|
@ -16,274 +16,17 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
"""
|
||||||
|
This module is deprecated. Please use `airflow.providers.sftp.hooks.sftp_hook`.
|
||||||
|
"""
|
||||||
|
|
||||||
import datetime
|
import warnings
|
||||||
import stat
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import pysftp
|
# pylint: disable=unused-import
|
||||||
|
from airflow.providers.sftp.hooks.sftp_hook import SFTPHook # noqa
|
||||||
|
|
||||||
from airflow.contrib.hooks.ssh_hook import SSHHook
|
warnings.warn(
|
||||||
|
"This module is deprecated. Please use `airflow.providers.sftp.hooks.sftp_hook`.",
|
||||||
|
DeprecationWarning,
|
||||||
class SFTPHook(SSHHook):
|
stacklevel=2,
|
||||||
"""
|
)
|
||||||
This hook is inherited from SSH hook. Please refer to SSH hook for the input
|
|
||||||
arguments.
|
|
||||||
|
|
||||||
Interact with SFTP. Aims to be interchangeable with FTPHook.
|
|
||||||
|
|
||||||
:Pitfalls::
|
|
||||||
|
|
||||||
- In contrast with FTPHook describe_directory only returns size, type and
|
|
||||||
modify. It doesn't return unix.owner, unix.mode, perm, unix.group and
|
|
||||||
unique.
|
|
||||||
- retrieve_file and store_file only take a local full path and not a
|
|
||||||
buffer.
|
|
||||||
- If no mode is passed to create_directory it will be created with 777
|
|
||||||
permissions.
|
|
||||||
|
|
||||||
Errors that may occur throughout but should be handled downstream.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, ftp_conn_id: str = 'sftp_default', *args, **kwargs) -> None:
|
|
||||||
kwargs['ssh_conn_id'] = ftp_conn_id
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
self.conn = None
|
|
||||||
self.private_key_pass = None
|
|
||||||
|
|
||||||
# Fail for unverified hosts, unless this is explicitly allowed
|
|
||||||
self.no_host_key_check = False
|
|
||||||
|
|
||||||
if self.ssh_conn_id is not None:
|
|
||||||
conn = self.get_connection(self.ssh_conn_id)
|
|
||||||
if conn.extra is not None:
|
|
||||||
extra_options = conn.extra_dejson
|
|
||||||
if 'private_key_pass' in extra_options:
|
|
||||||
self.private_key_pass = extra_options.get('private_key_pass', None)
|
|
||||||
|
|
||||||
# For backward compatibility
|
|
||||||
# TODO: remove in Airflow 2.1
|
|
||||||
import warnings
|
|
||||||
if 'ignore_hostkey_verification' in extra_options:
|
|
||||||
warnings.warn(
|
|
||||||
'Extra option `ignore_hostkey_verification` is deprecated.'
|
|
||||||
'Please use `no_host_key_check` instead.'
|
|
||||||
'This option will be removed in Airflow 2.1',
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
self.no_host_key_check = str(
|
|
||||||
extra_options['ignore_hostkey_verification']
|
|
||||||
).lower() == 'true'
|
|
||||||
|
|
||||||
if 'no_host_key_check' in extra_options:
|
|
||||||
self.no_host_key_check = str(
|
|
||||||
extra_options['no_host_key_check']).lower() == 'true'
|
|
||||||
|
|
||||||
if 'private_key' in extra_options:
|
|
||||||
warnings.warn(
|
|
||||||
'Extra option `private_key` is deprecated.'
|
|
||||||
'Please use `key_file` instead.'
|
|
||||||
'This option will be removed in Airflow 2.1',
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
self.key_file = extra_options.get('private_key')
|
|
||||||
|
|
||||||
def get_conn(self) -> pysftp.Connection:
|
|
||||||
"""
|
|
||||||
Returns an SFTP connection object
|
|
||||||
"""
|
|
||||||
if self.conn is None:
|
|
||||||
cnopts = pysftp.CnOpts()
|
|
||||||
if self.no_host_key_check:
|
|
||||||
cnopts.hostkeys = None
|
|
||||||
cnopts.compression = self.compress
|
|
||||||
conn_params = {
|
|
||||||
'host': self.remote_host,
|
|
||||||
'port': self.port,
|
|
||||||
'username': self.username,
|
|
||||||
'cnopts': cnopts
|
|
||||||
}
|
|
||||||
if self.password and self.password.strip():
|
|
||||||
conn_params['password'] = self.password
|
|
||||||
if self.key_file:
|
|
||||||
conn_params['private_key'] = self.key_file
|
|
||||||
if self.private_key_pass:
|
|
||||||
conn_params['private_key_pass'] = self.private_key_pass
|
|
||||||
|
|
||||||
self.conn = pysftp.Connection(**conn_params)
|
|
||||||
return self.conn
|
|
||||||
|
|
||||||
def close_conn(self) -> None:
|
|
||||||
"""
|
|
||||||
Closes the connection. An error will occur if the
|
|
||||||
connection wasnt ever opened.
|
|
||||||
"""
|
|
||||||
conn = self.conn
|
|
||||||
conn.close() # type: ignore
|
|
||||||
self.conn = None
|
|
||||||
|
|
||||||
def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]:
|
|
||||||
"""
|
|
||||||
Returns a dictionary of {filename: {attributes}} for all files
|
|
||||||
on the remote system (where the MLSD command is supported).
|
|
||||||
|
|
||||||
:param path: full path to the remote directory
|
|
||||||
:type path: str
|
|
||||||
"""
|
|
||||||
conn = self.get_conn()
|
|
||||||
flist = conn.listdir_attr(path)
|
|
||||||
files = {}
|
|
||||||
for f in flist:
|
|
||||||
modify = datetime.datetime.fromtimestamp(
|
|
||||||
f.st_mtime).strftime('%Y%m%d%H%M%S')
|
|
||||||
files[f.filename] = {
|
|
||||||
'size': f.st_size,
|
|
||||||
'type': 'dir' if stat.S_ISDIR(f.st_mode) else 'file',
|
|
||||||
'modify': modify}
|
|
||||||
return files
|
|
||||||
|
|
||||||
def list_directory(self, path: str) -> List[str]:
|
|
||||||
"""
|
|
||||||
Returns a list of files on the remote system.
|
|
||||||
|
|
||||||
:param path: full path to the remote directory to list
|
|
||||||
:type path: str
|
|
||||||
"""
|
|
||||||
conn = self.get_conn()
|
|
||||||
files = conn.listdir(path)
|
|
||||||
return files
|
|
||||||
|
|
||||||
def create_directory(self, path: str, mode: int = 777) -> None:
|
|
||||||
"""
|
|
||||||
Creates a directory on the remote system.
|
|
||||||
|
|
||||||
:param path: full path to the remote directory to create
|
|
||||||
:type path: str
|
|
||||||
:param mode: int representation of octal mode for directory
|
|
||||||
"""
|
|
||||||
conn = self.get_conn()
|
|
||||||
conn.makedirs(path, mode)
|
|
||||||
|
|
||||||
def delete_directory(self, path: str) -> None:
|
|
||||||
"""
|
|
||||||
Deletes a directory on the remote system.
|
|
||||||
|
|
||||||
:param path: full path to the remote directory to delete
|
|
||||||
:type path: str
|
|
||||||
"""
|
|
||||||
conn = self.get_conn()
|
|
||||||
conn.rmdir(path)
|
|
||||||
|
|
||||||
def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None:
|
|
||||||
"""
|
|
||||||
Transfers the remote file to a local location.
|
|
||||||
If local_full_path is a string path, the file will be put
|
|
||||||
at that location
|
|
||||||
|
|
||||||
:param remote_full_path: full path to the remote file
|
|
||||||
:type remote_full_path: str
|
|
||||||
:param local_full_path: full path to the local file
|
|
||||||
:type local_full_path: str
|
|
||||||
"""
|
|
||||||
conn = self.get_conn()
|
|
||||||
self.log.info('Retrieving file from FTP: %s', remote_full_path)
|
|
||||||
conn.get(remote_full_path, local_full_path)
|
|
||||||
self.log.info('Finished retrieving file from FTP: %s', remote_full_path)
|
|
||||||
|
|
||||||
def store_file(self, remote_full_path: str, local_full_path: str) -> None:
|
|
||||||
"""
|
|
||||||
Transfers a local file to the remote location.
|
|
||||||
If local_full_path_or_buffer is a string path, the file will be read
|
|
||||||
from that location
|
|
||||||
|
|
||||||
:param remote_full_path: full path to the remote file
|
|
||||||
:type remote_full_path: str
|
|
||||||
:param local_full_path: full path to the local file
|
|
||||||
:type local_full_path: str
|
|
||||||
"""
|
|
||||||
conn = self.get_conn()
|
|
||||||
conn.put(local_full_path, remote_full_path)
|
|
||||||
|
|
||||||
def delete_file(self, path: str) -> None:
|
|
||||||
"""
|
|
||||||
Removes a file on the FTP Server
|
|
||||||
|
|
||||||
:param path: full path to the remote file
|
|
||||||
:type path: str
|
|
||||||
"""
|
|
||||||
conn = self.get_conn()
|
|
||||||
conn.remove(path)
|
|
||||||
|
|
||||||
def get_mod_time(self, path: str) -> str:
|
|
||||||
conn = self.get_conn()
|
|
||||||
ftp_mdtm = conn.stat(path).st_mtime
|
|
||||||
return datetime.datetime.fromtimestamp(ftp_mdtm).strftime('%Y%m%d%H%M%S')
|
|
||||||
|
|
||||||
def path_exists(self, path: str) -> bool:
|
|
||||||
"""
|
|
||||||
Returns True if a remote entity exists
|
|
||||||
|
|
||||||
:param path: full path to the remote file or directory
|
|
||||||
:type path: str
|
|
||||||
"""
|
|
||||||
conn = self.get_conn()
|
|
||||||
return conn.exists(path)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None) -> bool:
|
|
||||||
"""
|
|
||||||
Return True if given path starts with prefix (if set) and ends with delimiter (if set).
|
|
||||||
|
|
||||||
:param path: path to be checked
|
|
||||||
:type path: str
|
|
||||||
:param prefix: if set path will be checked is starting with prefix
|
|
||||||
:type prefix: str
|
|
||||||
:param delimiter: if set path will be checked is ending with suffix
|
|
||||||
:type delimiter: str
|
|
||||||
:return: bool
|
|
||||||
"""
|
|
||||||
if prefix is not None and not path.startswith(prefix):
|
|
||||||
return False
|
|
||||||
if delimiter is not None and not path.endswith(delimiter):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_tree_map(
|
|
||||||
self, path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None
|
|
||||||
) -> Tuple[List[str], List[str], List[str]]:
|
|
||||||
"""
|
|
||||||
Return tuple with recursive lists of files, directories and unknown paths from given path.
|
|
||||||
It is possible to filter results by giving prefix and/or delimiter parameters.
|
|
||||||
|
|
||||||
:param path: path from which tree will be built
|
|
||||||
:type path: str
|
|
||||||
:param prefix: if set paths will be added if start with prefix
|
|
||||||
:type prefix: str
|
|
||||||
:param delimiter: if set paths will be added if end with delimiter
|
|
||||||
:type delimiter: str
|
|
||||||
:return: tuple with list of files, dirs and unknown items
|
|
||||||
:rtype: Tuple[List[str], List[str], List[str]]
|
|
||||||
"""
|
|
||||||
conn = self.get_conn()
|
|
||||||
files, dirs, unknowns = [], [], [] # type: List[str], List[str], List[str]
|
|
||||||
|
|
||||||
def append_matching_path_callback(list_):
|
|
||||||
return (
|
|
||||||
lambda item: list_.append(item)
|
|
||||||
if self._is_path_match(item, prefix, delimiter)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
conn.walktree(
|
|
||||||
remotepath=path,
|
|
||||||
fcallback=append_matching_path_callback(files),
|
|
||||||
dcallback=append_matching_path_callback(dirs),
|
|
||||||
ucallback=append_matching_path_callback(unknowns),
|
|
||||||
recurse=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return files, dirs, unknowns
|
|
||||||
|
|
|
@ -16,165 +16,17 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
import os
|
"""
|
||||||
|
This module is deprecated. Please use `airflow.providers.sftp.operators.sftp_operator`.
|
||||||
|
"""
|
||||||
|
|
||||||
from airflow.contrib.hooks.ssh_hook import SSHHook
|
import warnings
|
||||||
from airflow.exceptions import AirflowException
|
|
||||||
from airflow.models import BaseOperator
|
|
||||||
from airflow.utils.decorators import apply_defaults
|
|
||||||
|
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
from airflow.providers.sftp.operators.sftp_operator import SFTPOperator # noqa
|
||||||
|
|
||||||
class SFTPOperation:
|
warnings.warn(
|
||||||
PUT = 'put'
|
"This module is deprecated. Please use `airflow.providers.sftp.operators.sftp_operator`.",
|
||||||
GET = 'get'
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
class SFTPOperator(BaseOperator):
|
|
||||||
"""
|
|
||||||
SFTPOperator for transferring files from remote host to local or vice a versa.
|
|
||||||
This operator uses ssh_hook to open sftp transport channel that serve as basis
|
|
||||||
for file transfer.
|
|
||||||
|
|
||||||
:param ssh_hook: predefined ssh_hook to use for remote execution.
|
|
||||||
Either `ssh_hook` or `ssh_conn_id` needs to be provided.
|
|
||||||
:type ssh_hook: airflow.contrib.hooks.ssh_hook.SSHHook
|
|
||||||
:param ssh_conn_id: connection id from airflow Connections.
|
|
||||||
`ssh_conn_id` will be ignored if `ssh_hook` is provided.
|
|
||||||
:type ssh_conn_id: str
|
|
||||||
:param remote_host: remote host to connect (templated)
|
|
||||||
Nullable. If provided, it will replace the `remote_host` which was
|
|
||||||
defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
|
|
||||||
:type remote_host: str
|
|
||||||
:param local_filepath: local file path to get or put. (templated)
|
|
||||||
:type local_filepath: str
|
|
||||||
:param remote_filepath: remote file path to get or put. (templated)
|
|
||||||
:type remote_filepath: str
|
|
||||||
:param operation: specify operation 'get' or 'put', defaults to put
|
|
||||||
:type operation: str
|
|
||||||
:param confirm: specify if the SFTP operation should be confirmed, defaults to True
|
|
||||||
:type confirm: bool
|
|
||||||
:param create_intermediate_dirs: create missing intermediate directories when
|
|
||||||
copying from remote to local and vice-versa. Default is False.
|
|
||||||
|
|
||||||
Example: The following task would copy ``file.txt`` to the remote host
|
|
||||||
at ``/tmp/tmp1/tmp2/`` while creating ``tmp``,``tmp1`` and ``tmp2`` if they
|
|
||||||
don't exist. If the parameter is not passed it would error as the directory
|
|
||||||
does not exist. ::
|
|
||||||
|
|
||||||
put_file = SFTPOperator(
|
|
||||||
task_id="test_sftp",
|
|
||||||
ssh_conn_id="ssh_default",
|
|
||||||
local_filepath="/tmp/file.txt",
|
|
||||||
remote_filepath="/tmp/tmp1/tmp2/file.txt",
|
|
||||||
operation="put",
|
|
||||||
create_intermediate_dirs=True,
|
|
||||||
dag=dag
|
|
||||||
)
|
|
||||||
|
|
||||||
:type create_intermediate_dirs: bool
|
|
||||||
"""
|
|
||||||
template_fields = ('local_filepath', 'remote_filepath', 'remote_host')
|
|
||||||
|
|
||||||
@apply_defaults
|
|
||||||
def __init__(self,
|
|
||||||
ssh_hook=None,
|
|
||||||
ssh_conn_id=None,
|
|
||||||
remote_host=None,
|
|
||||||
local_filepath=None,
|
|
||||||
remote_filepath=None,
|
|
||||||
operation=SFTPOperation.PUT,
|
|
||||||
confirm=True,
|
|
||||||
create_intermediate_dirs=False,
|
|
||||||
*args,
|
|
||||||
**kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.ssh_hook = ssh_hook
|
|
||||||
self.ssh_conn_id = ssh_conn_id
|
|
||||||
self.remote_host = remote_host
|
|
||||||
self.local_filepath = local_filepath
|
|
||||||
self.remote_filepath = remote_filepath
|
|
||||||
self.operation = operation
|
|
||||||
self.confirm = confirm
|
|
||||||
self.create_intermediate_dirs = create_intermediate_dirs
|
|
||||||
if not (self.operation.lower() == SFTPOperation.GET or
|
|
||||||
self.operation.lower() == SFTPOperation.PUT):
|
|
||||||
raise TypeError("unsupported operation value {0}, expected {1} or {2}"
|
|
||||||
.format(self.operation, SFTPOperation.GET, SFTPOperation.PUT))
|
|
||||||
|
|
||||||
def execute(self, context):
|
|
||||||
file_msg = None
|
|
||||||
try:
|
|
||||||
if self.ssh_conn_id:
|
|
||||||
if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
|
|
||||||
self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
|
|
||||||
else:
|
|
||||||
self.log.info("ssh_hook is not provided or invalid. " +
|
|
||||||
"Trying ssh_conn_id to create SSHHook.")
|
|
||||||
self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)
|
|
||||||
|
|
||||||
if not self.ssh_hook:
|
|
||||||
raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")
|
|
||||||
|
|
||||||
if self.remote_host is not None:
|
|
||||||
self.log.info("remote_host is provided explicitly. " +
|
|
||||||
"It will replace the remote_host which was defined " +
|
|
||||||
"in ssh_hook or predefined in connection of ssh_conn_id.")
|
|
||||||
self.ssh_hook.remote_host = self.remote_host
|
|
||||||
|
|
||||||
with self.ssh_hook.get_conn() as ssh_client:
|
|
||||||
sftp_client = ssh_client.open_sftp()
|
|
||||||
if self.operation.lower() == SFTPOperation.GET:
|
|
||||||
local_folder = os.path.dirname(self.local_filepath)
|
|
||||||
if self.create_intermediate_dirs:
|
|
||||||
# Create Intermediate Directories if it doesn't exist
|
|
||||||
try:
|
|
||||||
os.makedirs(local_folder)
|
|
||||||
except OSError:
|
|
||||||
if not os.path.isdir(local_folder):
|
|
||||||
raise
|
|
||||||
file_msg = "from {0} to {1}".format(self.remote_filepath,
|
|
||||||
self.local_filepath)
|
|
||||||
self.log.info("Starting to transfer %s", file_msg)
|
|
||||||
sftp_client.get(self.remote_filepath, self.local_filepath)
|
|
||||||
else:
|
|
||||||
remote_folder = os.path.dirname(self.remote_filepath)
|
|
||||||
if self.create_intermediate_dirs:
|
|
||||||
_make_intermediate_dirs(
|
|
||||||
sftp_client=sftp_client,
|
|
||||||
remote_directory=remote_folder,
|
|
||||||
)
|
|
||||||
file_msg = "from {0} to {1}".format(self.local_filepath,
|
|
||||||
self.remote_filepath)
|
|
||||||
self.log.info("Starting to transfer file %s", file_msg)
|
|
||||||
sftp_client.put(self.local_filepath,
|
|
||||||
self.remote_filepath,
|
|
||||||
confirm=self.confirm)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise AirflowException("Error while transferring {0}, error: {1}"
|
|
||||||
.format(file_msg, str(e)))
|
|
||||||
|
|
||||||
return self.local_filepath
|
|
||||||
|
|
||||||
|
|
||||||
def _make_intermediate_dirs(sftp_client, remote_directory):
|
|
||||||
"""
|
|
||||||
Create all the intermediate directories in a remote host
|
|
||||||
|
|
||||||
:param sftp_client: A Paramiko SFTP client.
|
|
||||||
:param remote_directory: Absolute Path of the directory containing the file
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if remote_directory == '/':
|
|
||||||
sftp_client.chdir('/')
|
|
||||||
return
|
|
||||||
if remote_directory == '':
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
sftp_client.chdir(remote_directory)
|
|
||||||
except OSError:
|
|
||||||
dirname, basename = os.path.split(remote_directory.rstrip('/'))
|
|
||||||
_make_intermediate_dirs(sftp_client, dirname)
|
|
||||||
sftp_client.mkdir(basename)
|
|
||||||
sftp_client.chdir(basename)
|
|
||||||
return
|
|
||||||
|
|
|
@ -16,40 +16,17 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
"""
|
||||||
|
This module is deprecated. Please use `airflow.providers.sftp.sensors.sftp_sensor`.
|
||||||
|
"""
|
||||||
|
|
||||||
from paramiko import SFTP_NO_SUCH_FILE
|
import warnings
|
||||||
|
|
||||||
from airflow.contrib.hooks.sftp_hook import SFTPHook
|
# pylint: disable=unused-import
|
||||||
from airflow.sensors.base_sensor_operator import BaseSensorOperator
|
from airflow.providers.sftp.sensors.sftp_sensor import SFTPSensor # noqa
|
||||||
from airflow.utils.decorators import apply_defaults
|
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
class SFTPSensor(BaseSensorOperator):
|
"This module is deprecated. Please use `airflow.providers.sftp.sensors.sftp_sensor`.",
|
||||||
"""
|
DeprecationWarning,
|
||||||
Waits for a file or directory to be present on SFTP.
|
stacklevel=2,
|
||||||
|
)
|
||||||
:param path: Remote file or directory path
|
|
||||||
:type path: str
|
|
||||||
:param sftp_conn_id: The connection to run the sensor against
|
|
||||||
:type sftp_conn_id: str
|
|
||||||
"""
|
|
||||||
template_fields = ('path',)
|
|
||||||
|
|
||||||
@apply_defaults
|
|
||||||
def __init__(self, path, sftp_conn_id='sftp_default', *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.path = path
|
|
||||||
self.hook = None
|
|
||||||
self.sftp_conn_id = sftp_conn_id
|
|
||||||
|
|
||||||
def poke(self, context):
|
|
||||||
self.hook = SFTPHook(self.sftp_conn_id)
|
|
||||||
self.log.info('Poking for %s', self.path)
|
|
||||||
try:
|
|
||||||
self.hook.get_mod_time(self.path)
|
|
||||||
except OSError as e:
|
|
||||||
if e.errno != SFTP_NO_SUCH_FILE:
|
|
||||||
raise e
|
|
||||||
return False
|
|
||||||
self.hook.close_conn()
|
|
||||||
return True
|
|
||||||
|
|
|
@ -24,9 +24,9 @@ from tempfile import NamedTemporaryFile
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from airflow import AirflowException
|
from airflow import AirflowException
|
||||||
from airflow.contrib.hooks.sftp_hook import SFTPHook
|
|
||||||
from airflow.gcp.hooks.gcs import GoogleCloudStorageHook
|
from airflow.gcp.hooks.gcs import GoogleCloudStorageHook
|
||||||
from airflow.models import BaseOperator
|
from airflow.models import BaseOperator
|
||||||
|
from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
|
||||||
from airflow.utils.decorators import apply_defaults
|
from airflow.utils.decorators import apply_defaults
|
||||||
|
|
||||||
WILDCARD = "*"
|
WILDCARD = "*"
|
||||||
|
|
|
@ -23,9 +23,9 @@ from tempfile import NamedTemporaryFile
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from airflow import AirflowException
|
from airflow import AirflowException
|
||||||
from airflow.contrib.hooks.sftp_hook import SFTPHook
|
|
||||||
from airflow.gcp.hooks.gcs import GoogleCloudStorageHook
|
from airflow.gcp.hooks.gcs import GoogleCloudStorageHook
|
||||||
from airflow.models import BaseOperator
|
from airflow.models import BaseOperator
|
||||||
|
from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
|
||||||
from airflow.utils.decorators import apply_defaults
|
from airflow.utils.decorators import apply_defaults
|
||||||
|
|
||||||
WILDCARD = "*"
|
WILDCARD = "*"
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,297 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
This module contains SFTP hook.
|
||||||
|
"""
|
||||||
|
import datetime
|
||||||
|
import stat
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import pysftp
|
||||||
|
|
||||||
|
from airflow.contrib.hooks.ssh_hook import SSHHook
|
||||||
|
|
||||||
|
|
||||||
|
class SFTPHook(SSHHook):
|
||||||
|
"""
|
||||||
|
This hook is inherited from SSH hook. Please refer to SSH hook for the input
|
||||||
|
arguments.
|
||||||
|
|
||||||
|
Interact with SFTP. Aims to be interchangeable with FTPHook.
|
||||||
|
|
||||||
|
:Pitfalls::
|
||||||
|
|
||||||
|
- In contrast with FTPHook describe_directory only returns size, type and
|
||||||
|
modify. It doesn't return unix.owner, unix.mode, perm, unix.group and
|
||||||
|
unique.
|
||||||
|
- retrieve_file and store_file only take a local full path and not a
|
||||||
|
buffer.
|
||||||
|
- If no mode is passed to create_directory it will be created with 777
|
||||||
|
permissions.
|
||||||
|
|
||||||
|
Errors that may occur throughout but should be handled downstream.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ftp_conn_id: str = 'sftp_default', *args, **kwargs) -> None:
|
||||||
|
kwargs['ssh_conn_id'] = ftp_conn_id
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self.conn = None
|
||||||
|
self.private_key_pass = None
|
||||||
|
|
||||||
|
# Fail for unverified hosts, unless this is explicitly allowed
|
||||||
|
self.no_host_key_check = False
|
||||||
|
|
||||||
|
if self.ssh_conn_id is not None:
|
||||||
|
conn = self.get_connection(self.ssh_conn_id)
|
||||||
|
if conn.extra is not None:
|
||||||
|
extra_options = conn.extra_dejson
|
||||||
|
if 'private_key_pass' in extra_options:
|
||||||
|
self.private_key_pass = extra_options.get('private_key_pass', None)
|
||||||
|
|
||||||
|
# For backward compatibility
|
||||||
|
# TODO: remove in Airflow 2.1
|
||||||
|
import warnings
|
||||||
|
if 'ignore_hostkey_verification' in extra_options:
|
||||||
|
warnings.warn(
|
||||||
|
'Extra option `ignore_hostkey_verification` is deprecated.'
|
||||||
|
'Please use `no_host_key_check` instead.'
|
||||||
|
'This option will be removed in Airflow 2.1',
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
self.no_host_key_check = str(
|
||||||
|
extra_options['ignore_hostkey_verification']
|
||||||
|
).lower() == 'true'
|
||||||
|
|
||||||
|
if 'no_host_key_check' in extra_options:
|
||||||
|
self.no_host_key_check = str(
|
||||||
|
extra_options['no_host_key_check']).lower() == 'true'
|
||||||
|
|
||||||
|
if 'private_key' in extra_options:
|
||||||
|
warnings.warn(
|
||||||
|
'Extra option `private_key` is deprecated.'
|
||||||
|
'Please use `key_file` instead.'
|
||||||
|
'This option will be removed in Airflow 2.1',
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
self.key_file = extra_options.get('private_key')
|
||||||
|
|
||||||
|
def get_conn(self) -> pysftp.Connection:
|
||||||
|
"""
|
||||||
|
Returns an SFTP connection object
|
||||||
|
"""
|
||||||
|
if self.conn is None:
|
||||||
|
cnopts = pysftp.CnOpts()
|
||||||
|
if self.no_host_key_check:
|
||||||
|
cnopts.hostkeys = None
|
||||||
|
cnopts.compression = self.compress
|
||||||
|
conn_params = {
|
||||||
|
'host': self.remote_host,
|
||||||
|
'port': self.port,
|
||||||
|
'username': self.username,
|
||||||
|
'cnopts': cnopts
|
||||||
|
}
|
||||||
|
if self.password and self.password.strip():
|
||||||
|
conn_params['password'] = self.password
|
||||||
|
if self.key_file:
|
||||||
|
conn_params['private_key'] = self.key_file
|
||||||
|
if self.private_key_pass:
|
||||||
|
conn_params['private_key_pass'] = self.private_key_pass
|
||||||
|
|
||||||
|
self.conn = pysftp.Connection(**conn_params)
|
||||||
|
return self.conn
|
||||||
|
|
||||||
|
def close_conn(self) -> None:
|
||||||
|
"""
|
||||||
|
Closes the connection. An error will occur if the
|
||||||
|
connection wasnt ever opened.
|
||||||
|
"""
|
||||||
|
conn = self.conn
|
||||||
|
conn.close() # type: ignore
|
||||||
|
self.conn = None
|
||||||
|
|
||||||
|
def describe_directory(self, path: str) -> Dict[str, Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
Returns a dictionary of {filename: {attributes}} for all files
|
||||||
|
on the remote system (where the MLSD command is supported).
|
||||||
|
|
||||||
|
:param path: full path to the remote directory
|
||||||
|
:type path: str
|
||||||
|
"""
|
||||||
|
conn = self.get_conn()
|
||||||
|
flist = conn.listdir_attr(path)
|
||||||
|
files = {}
|
||||||
|
for f in flist:
|
||||||
|
modify = datetime.datetime.fromtimestamp(
|
||||||
|
f.st_mtime).strftime('%Y%m%d%H%M%S')
|
||||||
|
files[f.filename] = {
|
||||||
|
'size': f.st_size,
|
||||||
|
'type': 'dir' if stat.S_ISDIR(f.st_mode) else 'file',
|
||||||
|
'modify': modify}
|
||||||
|
return files
|
||||||
|
|
||||||
|
def list_directory(self, path: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
Returns a list of files on the remote system.
|
||||||
|
|
||||||
|
:param path: full path to the remote directory to list
|
||||||
|
:type path: str
|
||||||
|
"""
|
||||||
|
conn = self.get_conn()
|
||||||
|
files = conn.listdir(path)
|
||||||
|
return files
|
||||||
|
|
||||||
|
def create_directory(self, path: str, mode: int = 777) -> None:
|
||||||
|
"""
|
||||||
|
Creates a directory on the remote system.
|
||||||
|
|
||||||
|
:param path: full path to the remote directory to create
|
||||||
|
:type path: str
|
||||||
|
:param mode: int representation of octal mode for directory
|
||||||
|
"""
|
||||||
|
conn = self.get_conn()
|
||||||
|
conn.makedirs(path, mode)
|
||||||
|
|
||||||
|
def delete_directory(self, path: str) -> None:
|
||||||
|
"""
|
||||||
|
Deletes a directory on the remote system.
|
||||||
|
|
||||||
|
:param path: full path to the remote directory to delete
|
||||||
|
:type path: str
|
||||||
|
"""
|
||||||
|
conn = self.get_conn()
|
||||||
|
conn.rmdir(path)
|
||||||
|
|
||||||
|
def retrieve_file(self, remote_full_path: str, local_full_path: str) -> None:
|
||||||
|
"""
|
||||||
|
Transfers the remote file to a local location.
|
||||||
|
If local_full_path is a string path, the file will be put
|
||||||
|
at that location
|
||||||
|
|
||||||
|
:param remote_full_path: full path to the remote file
|
||||||
|
:type remote_full_path: str
|
||||||
|
:param local_full_path: full path to the local file
|
||||||
|
:type local_full_path: str
|
||||||
|
"""
|
||||||
|
conn = self.get_conn()
|
||||||
|
self.log.info('Retrieving file from FTP: %s', remote_full_path)
|
||||||
|
conn.get(remote_full_path, local_full_path)
|
||||||
|
self.log.info('Finished retrieving file from FTP: %s', remote_full_path)
|
||||||
|
|
||||||
|
def store_file(self, remote_full_path: str, local_full_path: str) -> None:
|
||||||
|
"""
|
||||||
|
Transfers a local file to the remote location.
|
||||||
|
If local_full_path_or_buffer is a string path, the file will be read
|
||||||
|
from that location
|
||||||
|
|
||||||
|
:param remote_full_path: full path to the remote file
|
||||||
|
:type remote_full_path: str
|
||||||
|
:param local_full_path: full path to the local file
|
||||||
|
:type local_full_path: str
|
||||||
|
"""
|
||||||
|
conn = self.get_conn()
|
||||||
|
conn.put(local_full_path, remote_full_path)
|
||||||
|
|
||||||
|
def delete_file(self, path: str) -> None:
|
||||||
|
"""
|
||||||
|
Removes a file on the FTP Server
|
||||||
|
|
||||||
|
:param path: full path to the remote file
|
||||||
|
:type path: str
|
||||||
|
"""
|
||||||
|
conn = self.get_conn()
|
||||||
|
conn.remove(path)
|
||||||
|
|
||||||
|
def get_mod_time(self, path: str) -> str:
|
||||||
|
"""
|
||||||
|
Returns modification time.
|
||||||
|
|
||||||
|
:param path: full path to the remote file
|
||||||
|
:type path: str
|
||||||
|
"""
|
||||||
|
conn = self.get_conn()
|
||||||
|
ftp_mdtm = conn.stat(path).st_mtime
|
||||||
|
return datetime.datetime.fromtimestamp(ftp_mdtm).strftime('%Y%m%d%H%M%S')
|
||||||
|
|
||||||
|
def path_exists(self, path: str) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if a remote entity exists
|
||||||
|
|
||||||
|
:param path: full path to the remote file or directory
|
||||||
|
:type path: str
|
||||||
|
"""
|
||||||
|
conn = self.get_conn()
|
||||||
|
return conn.exists(path)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_path_match(path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None) -> bool:
|
||||||
|
"""
|
||||||
|
Return True if given path starts with prefix (if set) and ends with delimiter (if set).
|
||||||
|
|
||||||
|
:param path: path to be checked
|
||||||
|
:type path: str
|
||||||
|
:param prefix: if set path will be checked is starting with prefix
|
||||||
|
:type prefix: str
|
||||||
|
:param delimiter: if set path will be checked is ending with suffix
|
||||||
|
:type delimiter: str
|
||||||
|
:return: bool
|
||||||
|
"""
|
||||||
|
if prefix is not None and not path.startswith(prefix):
|
||||||
|
return False
|
||||||
|
if delimiter is not None and not path.endswith(delimiter):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_tree_map(
|
||||||
|
self, path: str, prefix: Optional[str] = None, delimiter: Optional[str] = None
|
||||||
|
) -> Tuple[List[str], List[str], List[str]]:
|
||||||
|
"""
|
||||||
|
Return tuple with recursive lists of files, directories and unknown paths from given path.
|
||||||
|
It is possible to filter results by giving prefix and/or delimiter parameters.
|
||||||
|
|
||||||
|
:param path: path from which tree will be built
|
||||||
|
:type path: str
|
||||||
|
:param prefix: if set paths will be added if start with prefix
|
||||||
|
:type prefix: str
|
||||||
|
:param delimiter: if set paths will be added if end with delimiter
|
||||||
|
:type delimiter: str
|
||||||
|
:return: tuple with list of files, dirs and unknown items
|
||||||
|
:rtype: Tuple[List[str], List[str], List[str]]
|
||||||
|
"""
|
||||||
|
conn = self.get_conn()
|
||||||
|
files, dirs, unknowns = [], [], [] # type: List[str], List[str], List[str]
|
||||||
|
|
||||||
|
def append_matching_path_callback(list_):
|
||||||
|
return (
|
||||||
|
lambda item: list_.append(item)
|
||||||
|
if self._is_path_match(item, prefix, delimiter)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.walktree(
|
||||||
|
remotepath=path,
|
||||||
|
fcallback=append_matching_path_callback(files),
|
||||||
|
dcallback=append_matching_path_callback(dirs),
|
||||||
|
ucallback=append_matching_path_callback(unknowns),
|
||||||
|
recurse=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return files, dirs, unknowns
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,184 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
This module contains SFTP operator.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
from airflow.contrib.hooks.ssh_hook import SSHHook
|
||||||
|
from airflow.exceptions import AirflowException
|
||||||
|
from airflow.models import BaseOperator
|
||||||
|
from airflow.utils.decorators import apply_defaults
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=missing-docstring
|
||||||
|
class SFTPOperation:
|
||||||
|
PUT = 'put'
|
||||||
|
GET = 'get'
|
||||||
|
|
||||||
|
|
||||||
|
class SFTPOperator(BaseOperator):
|
||||||
|
"""
|
||||||
|
SFTPOperator for transferring files from remote host to local or vice a versa.
|
||||||
|
This operator uses ssh_hook to open sftp transport channel that serve as basis
|
||||||
|
for file transfer.
|
||||||
|
|
||||||
|
:param ssh_hook: predefined ssh_hook to use for remote execution.
|
||||||
|
Either `ssh_hook` or `ssh_conn_id` needs to be provided.
|
||||||
|
:type ssh_hook: airflow.contrib.hooks.ssh_hook.SSHHook
|
||||||
|
:param ssh_conn_id: connection id from airflow Connections.
|
||||||
|
`ssh_conn_id` will be ignored if `ssh_hook` is provided.
|
||||||
|
:type ssh_conn_id: str
|
||||||
|
:param remote_host: remote host to connect (templated)
|
||||||
|
Nullable. If provided, it will replace the `remote_host` which was
|
||||||
|
defined in `ssh_hook` or predefined in the connection of `ssh_conn_id`.
|
||||||
|
:type remote_host: str
|
||||||
|
:param local_filepath: local file path to get or put. (templated)
|
||||||
|
:type local_filepath: str
|
||||||
|
:param remote_filepath: remote file path to get or put. (templated)
|
||||||
|
:type remote_filepath: str
|
||||||
|
:param operation: specify operation 'get' or 'put', defaults to put
|
||||||
|
:type operation: str
|
||||||
|
:param confirm: specify if the SFTP operation should be confirmed, defaults to True
|
||||||
|
:type confirm: bool
|
||||||
|
:param create_intermediate_dirs: create missing intermediate directories when
|
||||||
|
copying from remote to local and vice-versa. Default is False.
|
||||||
|
|
||||||
|
Example: The following task would copy ``file.txt`` to the remote host
|
||||||
|
at ``/tmp/tmp1/tmp2/`` while creating ``tmp``,``tmp1`` and ``tmp2`` if they
|
||||||
|
don't exist. If the parameter is not passed it would error as the directory
|
||||||
|
does not exist. ::
|
||||||
|
|
||||||
|
put_file = SFTPOperator(
|
||||||
|
task_id="test_sftp",
|
||||||
|
ssh_conn_id="ssh_default",
|
||||||
|
local_filepath="/tmp/file.txt",
|
||||||
|
remote_filepath="/tmp/tmp1/tmp2/file.txt",
|
||||||
|
operation="put",
|
||||||
|
create_intermediate_dirs=True,
|
||||||
|
dag=dag
|
||||||
|
)
|
||||||
|
|
||||||
|
:type create_intermediate_dirs: bool
|
||||||
|
"""
|
||||||
|
template_fields = ('local_filepath', 'remote_filepath', 'remote_host')
|
||||||
|
|
||||||
|
@apply_defaults
|
||||||
|
def __init__(self,
|
||||||
|
ssh_hook=None,
|
||||||
|
ssh_conn_id=None,
|
||||||
|
remote_host=None,
|
||||||
|
local_filepath=None,
|
||||||
|
remote_filepath=None,
|
||||||
|
operation=SFTPOperation.PUT,
|
||||||
|
confirm=True,
|
||||||
|
create_intermediate_dirs=False,
|
||||||
|
*args,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.ssh_hook = ssh_hook
|
||||||
|
self.ssh_conn_id = ssh_conn_id
|
||||||
|
self.remote_host = remote_host
|
||||||
|
self.local_filepath = local_filepath
|
||||||
|
self.remote_filepath = remote_filepath
|
||||||
|
self.operation = operation
|
||||||
|
self.confirm = confirm
|
||||||
|
self.create_intermediate_dirs = create_intermediate_dirs
|
||||||
|
if not (self.operation.lower() == SFTPOperation.GET or
|
||||||
|
self.operation.lower() == SFTPOperation.PUT):
|
||||||
|
raise TypeError("unsupported operation value {0}, expected {1} or {2}"
|
||||||
|
.format(self.operation, SFTPOperation.GET, SFTPOperation.PUT))
|
||||||
|
|
||||||
|
def execute(self, context):
|
||||||
|
file_msg = None
|
||||||
|
try:
|
||||||
|
if self.ssh_conn_id:
|
||||||
|
if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
|
||||||
|
self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
|
||||||
|
else:
|
||||||
|
self.log.info("ssh_hook is not provided or invalid. "
|
||||||
|
"Trying ssh_conn_id to create SSHHook.")
|
||||||
|
self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)
|
||||||
|
|
||||||
|
if not self.ssh_hook:
|
||||||
|
raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")
|
||||||
|
|
||||||
|
if self.remote_host is not None:
|
||||||
|
self.log.info("remote_host is provided explicitly. "
|
||||||
|
"It will replace the remote_host which was defined "
|
||||||
|
"in ssh_hook or predefined in connection of ssh_conn_id.")
|
||||||
|
self.ssh_hook.remote_host = self.remote_host
|
||||||
|
|
||||||
|
with self.ssh_hook.get_conn() as ssh_client:
|
||||||
|
sftp_client = ssh_client.open_sftp()
|
||||||
|
if self.operation.lower() == SFTPOperation.GET:
|
||||||
|
local_folder = os.path.dirname(self.local_filepath)
|
||||||
|
if self.create_intermediate_dirs:
|
||||||
|
# Create Intermediate Directories if it doesn't exist
|
||||||
|
try:
|
||||||
|
os.makedirs(local_folder)
|
||||||
|
except OSError:
|
||||||
|
if not os.path.isdir(local_folder):
|
||||||
|
raise
|
||||||
|
file_msg = "from {0} to {1}".format(self.remote_filepath,
|
||||||
|
self.local_filepath)
|
||||||
|
self.log.info("Starting to transfer %s", file_msg)
|
||||||
|
sftp_client.get(self.remote_filepath, self.local_filepath)
|
||||||
|
else:
|
||||||
|
remote_folder = os.path.dirname(self.remote_filepath)
|
||||||
|
if self.create_intermediate_dirs:
|
||||||
|
_make_intermediate_dirs(
|
||||||
|
sftp_client=sftp_client,
|
||||||
|
remote_directory=remote_folder,
|
||||||
|
)
|
||||||
|
file_msg = "from {0} to {1}".format(self.local_filepath,
|
||||||
|
self.remote_filepath)
|
||||||
|
self.log.info("Starting to transfer file %s", file_msg)
|
||||||
|
sftp_client.put(self.local_filepath,
|
||||||
|
self.remote_filepath,
|
||||||
|
confirm=self.confirm)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise AirflowException("Error while transferring {0}, error: {1}"
|
||||||
|
.format(file_msg, str(e)))
|
||||||
|
|
||||||
|
return self.local_filepath
|
||||||
|
|
||||||
|
|
||||||
|
def _make_intermediate_dirs(sftp_client, remote_directory):
|
||||||
|
"""
|
||||||
|
Create all the intermediate directories in a remote host
|
||||||
|
|
||||||
|
:param sftp_client: A Paramiko SFTP client.
|
||||||
|
:param remote_directory: Absolute Path of the directory containing the file
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if remote_directory == '/':
|
||||||
|
sftp_client.chdir('/')
|
||||||
|
return
|
||||||
|
if remote_directory == '':
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
sftp_client.chdir(remote_directory)
|
||||||
|
except OSError:
|
||||||
|
dirname, basename = os.path.split(remote_directory.rstrip('/'))
|
||||||
|
_make_intermediate_dirs(sftp_client, dirname)
|
||||||
|
sftp_client.mkdir(basename)
|
||||||
|
sftp_client.chdir(basename)
|
||||||
|
return
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,57 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
This module contains SFTP sensor.
|
||||||
|
"""
|
||||||
|
from paramiko import SFTP_NO_SUCH_FILE
|
||||||
|
|
||||||
|
from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
|
||||||
|
from airflow.sensors.base_sensor_operator import BaseSensorOperator
|
||||||
|
from airflow.utils.decorators import apply_defaults
|
||||||
|
|
||||||
|
|
||||||
|
class SFTPSensor(BaseSensorOperator):
|
||||||
|
"""
|
||||||
|
Waits for a file or directory to be present on SFTP.
|
||||||
|
|
||||||
|
:param path: Remote file or directory path
|
||||||
|
:type path: str
|
||||||
|
:param sftp_conn_id: The connection to run the sensor against
|
||||||
|
:type sftp_conn_id: str
|
||||||
|
"""
|
||||||
|
template_fields = ('path',)
|
||||||
|
|
||||||
|
@apply_defaults
|
||||||
|
def __init__(self, path, sftp_conn_id='sftp_default', *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.path = path
|
||||||
|
self.hook = None
|
||||||
|
self.sftp_conn_id = sftp_conn_id
|
||||||
|
|
||||||
|
def poke(self, context):
|
||||||
|
self.hook = SFTPHook(self.sftp_conn_id)
|
||||||
|
self.log.info('Poking for %s', self.path)
|
||||||
|
try:
|
||||||
|
self.hook.get_mod_time(self.path)
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno != SFTP_NO_SUCH_FILE:
|
||||||
|
raise e
|
||||||
|
return False
|
||||||
|
self.hook.close_conn()
|
||||||
|
return True
|
|
@ -76,6 +76,8 @@ All operators are in the following packages:
|
||||||
|
|
||||||
airflow/providers/amazon/aws/sensors/index
|
airflow/providers/amazon/aws/sensors/index
|
||||||
|
|
||||||
|
airflow/providers/apache/cassandra/sensors/index
|
||||||
|
|
||||||
airflow/providers/google/cloud/operators/index
|
airflow/providers/google/cloud/operators/index
|
||||||
|
|
||||||
airflow/providers/google/cloud/sensors/index
|
airflow/providers/google/cloud/sensors/index
|
||||||
|
@ -84,7 +86,9 @@ All operators are in the following packages:
|
||||||
|
|
||||||
airflow/providers/google/marketing_platform/sensors/index
|
airflow/providers/google/marketing_platform/sensors/index
|
||||||
|
|
||||||
airflow/providers/apache/cassandra/sensors/index
|
airflow/providers/sftp/operators/index
|
||||||
|
|
||||||
|
airflow/providers/sftp/sensors/index
|
||||||
|
|
||||||
Hooks
|
Hooks
|
||||||
-----
|
-----
|
||||||
|
@ -117,6 +121,8 @@ All hooks are in the following packages:
|
||||||
|
|
||||||
airflow/providers/apache/cassandra/hooks/index
|
airflow/providers/apache/cassandra/hooks/index
|
||||||
|
|
||||||
|
airflow/providers/sftp/hooks/index
|
||||||
|
|
||||||
Executors
|
Executors
|
||||||
---------
|
---------
|
||||||
Executors are the mechanism by which task instances get run. All executors are
|
Executors are the mechanism by which task instances get run. All executors are
|
||||||
|
|
|
@ -233,6 +233,7 @@ exclude_patterns = [
|
||||||
'_api/airflow/providers/amazon/aws/example_dags',
|
'_api/airflow/providers/amazon/aws/example_dags',
|
||||||
'_api/airflow/providers/apache/index.rst',
|
'_api/airflow/providers/apache/index.rst',
|
||||||
'_api/airflow/providers/apache/cassandra/index.rst',
|
'_api/airflow/providers/apache/cassandra/index.rst',
|
||||||
|
'_api/airflow/providers/sftp/index.rst',
|
||||||
'_api/enums/index.rst',
|
'_api/enums/index.rst',
|
||||||
'_api/json_schema/index.rst',
|
'_api/json_schema/index.rst',
|
||||||
'_api/base_serialization/index.rst',
|
'_api/base_serialization/index.rst',
|
||||||
|
|
|
@ -1246,9 +1246,9 @@ communication protocols or interface.
|
||||||
|
|
||||||
* - `SSH File Transfer Protocol (SFTP) <https://tools.ietf.org/wg/secsh/draft-ietf-secsh-filexfer/>`__
|
* - `SSH File Transfer Protocol (SFTP) <https://tools.ietf.org/wg/secsh/draft-ietf-secsh-filexfer/>`__
|
||||||
-
|
-
|
||||||
- :mod:`airflow.contrib.hooks.sftp_hook`
|
- :mod:`airflow.providers.sftp.hooks.sftp_hook`
|
||||||
- :mod:`airflow.contrib.operators.sftp_operator`
|
- :mod:`airflow.providers.sftp.operators.sftp_operator`
|
||||||
- :mod:`airflow.contrib.sensors.sftp_sensor`
|
- :mod:`airflow.providers.sftp.sensors.sftp_sensor`
|
||||||
|
|
||||||
* - `Secure Shell (SSH) <https://tools.ietf.org/html/rfc4251>`__
|
* - `Secure Shell (SSH) <https://tools.ietf.org/html/rfc4251>`__
|
||||||
-
|
-
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
./airflow/contrib/hooks/qubole_check_hook.py
|
./airflow/contrib/hooks/qubole_check_hook.py
|
||||||
./airflow/contrib/hooks/sagemaker_hook.py
|
./airflow/contrib/hooks/sagemaker_hook.py
|
||||||
./airflow/contrib/hooks/segment_hook.py
|
./airflow/contrib/hooks/segment_hook.py
|
||||||
./airflow/contrib/hooks/sftp_hook.py
|
|
||||||
./airflow/contrib/hooks/slack_webhook_hook.py
|
./airflow/contrib/hooks/slack_webhook_hook.py
|
||||||
./airflow/contrib/hooks/snowflake_hook.py
|
./airflow/contrib/hooks/snowflake_hook.py
|
||||||
./airflow/contrib/hooks/spark_jdbc_hook.py
|
./airflow/contrib/hooks/spark_jdbc_hook.py
|
||||||
|
@ -65,7 +64,6 @@
|
||||||
./airflow/contrib/operators/sagemaker_transform_operator.py
|
./airflow/contrib/operators/sagemaker_transform_operator.py
|
||||||
./airflow/contrib/operators/sagemaker_tuning_operator.py
|
./airflow/contrib/operators/sagemaker_tuning_operator.py
|
||||||
./airflow/contrib/operators/segment_track_event_operator.py
|
./airflow/contrib/operators/segment_track_event_operator.py
|
||||||
./airflow/contrib/operators/sftp_operator.py
|
|
||||||
./airflow/contrib/operators/sftp_to_s3_operator.py
|
./airflow/contrib/operators/sftp_to_s3_operator.py
|
||||||
./airflow/contrib/operators/slack_webhook_operator.py
|
./airflow/contrib/operators/slack_webhook_operator.py
|
||||||
./airflow/contrib/operators/snowflake_operator.py
|
./airflow/contrib/operators/snowflake_operator.py
|
||||||
|
@ -100,7 +98,6 @@
|
||||||
./airflow/contrib/sensors/sagemaker_training_sensor.py
|
./airflow/contrib/sensors/sagemaker_training_sensor.py
|
||||||
./airflow/contrib/sensors/sagemaker_transform_sensor.py
|
./airflow/contrib/sensors/sagemaker_transform_sensor.py
|
||||||
./airflow/contrib/sensors/sagemaker_tuning_sensor.py
|
./airflow/contrib/sensors/sagemaker_tuning_sensor.py
|
||||||
./airflow/contrib/sensors/sftp_sensor.py
|
|
||||||
./airflow/contrib/sensors/wasb_sensor.py
|
./airflow/contrib/sensors/wasb_sensor.py
|
||||||
./airflow/contrib/sensors/weekday_sensor.py
|
./airflow/contrib/sensors/weekday_sensor.py
|
||||||
./airflow/hooks/dbapi_hook.py
|
./airflow/hooks/dbapi_hook.py
|
||||||
|
|
|
@ -17,231 +17,17 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import os
|
import importlib
|
||||||
import shutil
|
from unittest import TestCase
|
||||||
import unittest
|
|
||||||
from unittest import mock
|
|
||||||
|
|
||||||
import pysftp
|
OLD_PATH = "airflow.contrib.hooks.sftp_hook"
|
||||||
from parameterized import parameterized
|
NEW_PATH = "airflow.providers.sftp.hooks.sftp_hook"
|
||||||
|
WARNING_MESSAGE = "This module is deprecated. Please use `{}`.".format(NEW_PATH)
|
||||||
from airflow.contrib.hooks.sftp_hook import SFTPHook
|
|
||||||
from airflow.models import Connection
|
|
||||||
from airflow.utils.db import provide_session
|
|
||||||
|
|
||||||
TMP_PATH = '/tmp'
|
|
||||||
TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
|
|
||||||
SUB_DIR = "sub_dir"
|
|
||||||
TMP_FILE_FOR_TESTS = 'test_file.txt'
|
|
||||||
|
|
||||||
SFTP_CONNECTION_USER = "root"
|
|
||||||
|
|
||||||
|
|
||||||
class TestSFTPHook(unittest.TestCase):
|
class TestMovingSFTPHookToCore(TestCase):
|
||||||
|
def test_move_sftp_hook_to_core(self):
|
||||||
@provide_session
|
with self.assertWarns(DeprecationWarning) as warn:
|
||||||
def update_connection(self, login, session=None):
|
# Reload to see deprecation warning each time
|
||||||
connection = (session.query(Connection).
|
importlib.import_module(OLD_PATH)
|
||||||
filter(Connection.conn_id == "sftp_default")
|
self.assertEqual(WARNING_MESSAGE, str(warn.warning))
|
||||||
.first())
|
|
||||||
old_login = connection.login
|
|
||||||
connection.login = login
|
|
||||||
session.commit()
|
|
||||||
return old_login
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.old_login = self.update_connection(SFTP_CONNECTION_USER)
|
|
||||||
self.hook = SFTPHook()
|
|
||||||
os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))
|
|
||||||
|
|
||||||
with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file:
|
|
||||||
file.write('Test file')
|
|
||||||
with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file:
|
|
||||||
file.write('Test file')
|
|
||||||
|
|
||||||
def test_get_conn(self):
|
|
||||||
output = self.hook.get_conn()
|
|
||||||
self.assertEqual(type(output), pysftp.Connection)
|
|
||||||
|
|
||||||
def test_close_conn(self):
|
|
||||||
self.hook.conn = self.hook.get_conn()
|
|
||||||
self.assertTrue(self.hook.conn is not None)
|
|
||||||
self.hook.close_conn()
|
|
||||||
self.assertTrue(self.hook.conn is None)
|
|
||||||
|
|
||||||
def test_describe_directory(self):
|
|
||||||
output = self.hook.describe_directory(TMP_PATH)
|
|
||||||
self.assertTrue(TMP_DIR_FOR_TESTS in output)
|
|
||||||
|
|
||||||
def test_list_directory(self):
|
|
||||||
output = self.hook.list_directory(
|
|
||||||
path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
|
||||||
self.assertEqual(output, [SUB_DIR])
|
|
||||||
|
|
||||||
def test_create_and_delete_directory(self):
|
|
||||||
new_dir_name = 'new_dir'
|
|
||||||
self.hook.create_directory(os.path.join(
|
|
||||||
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
|
|
||||||
output = self.hook.describe_directory(
|
|
||||||
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
|
||||||
self.assertTrue(new_dir_name in output)
|
|
||||||
self.hook.delete_directory(os.path.join(
|
|
||||||
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
|
|
||||||
output = self.hook.describe_directory(
|
|
||||||
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
|
||||||
self.assertTrue(new_dir_name not in output)
|
|
||||||
|
|
||||||
def test_create_and_delete_directories(self):
|
|
||||||
base_dir = "base_dir"
|
|
||||||
sub_dir = "sub_dir"
|
|
||||||
new_dir_path = os.path.join(base_dir, sub_dir)
|
|
||||||
self.hook.create_directory(os.path.join(
|
|
||||||
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
|
|
||||||
output = self.hook.describe_directory(
|
|
||||||
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
|
||||||
self.assertTrue(base_dir in output)
|
|
||||||
output = self.hook.describe_directory(
|
|
||||||
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
|
|
||||||
self.assertTrue(sub_dir in output)
|
|
||||||
self.hook.delete_directory(os.path.join(
|
|
||||||
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
|
|
||||||
self.hook.delete_directory(os.path.join(
|
|
||||||
TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
|
|
||||||
output = self.hook.describe_directory(
|
|
||||||
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
|
||||||
self.assertTrue(new_dir_path not in output)
|
|
||||||
self.assertTrue(base_dir not in output)
|
|
||||||
|
|
||||||
def test_store_retrieve_and_delete_file(self):
|
|
||||||
self.hook.store_file(
|
|
||||||
remote_full_path=os.path.join(
|
|
||||||
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
|
|
||||||
local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
|
|
||||||
)
|
|
||||||
output = self.hook.list_directory(
|
|
||||||
path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
|
||||||
self.assertEqual(output, [SUB_DIR, TMP_FILE_FOR_TESTS])
|
|
||||||
retrieved_file_name = 'retrieved.txt'
|
|
||||||
self.hook.retrieve_file(
|
|
||||||
remote_full_path=os.path.join(
|
|
||||||
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
|
|
||||||
local_full_path=os.path.join(TMP_PATH, retrieved_file_name)
|
|
||||||
)
|
|
||||||
self.assertTrue(retrieved_file_name in os.listdir(TMP_PATH))
|
|
||||||
os.remove(os.path.join(TMP_PATH, retrieved_file_name))
|
|
||||||
self.hook.delete_file(path=os.path.join(
|
|
||||||
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
|
|
||||||
output = self.hook.list_directory(
|
|
||||||
path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
|
||||||
self.assertEqual(output, [SUB_DIR])
|
|
||||||
|
|
||||||
def test_get_mod_time(self):
|
|
||||||
self.hook.store_file(
|
|
||||||
remote_full_path=os.path.join(
|
|
||||||
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
|
|
||||||
local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
|
|
||||||
)
|
|
||||||
output = self.hook.get_mod_time(path=os.path.join(
|
|
||||||
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
|
|
||||||
self.assertEqual(len(output), 14)
|
|
||||||
|
|
||||||
@mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
|
|
||||||
def test_no_host_key_check_default(self, get_connection):
|
|
||||||
connection = Connection(login='login', host='host')
|
|
||||||
get_connection.return_value = connection
|
|
||||||
hook = SFTPHook()
|
|
||||||
self.assertEqual(hook.no_host_key_check, False)
|
|
||||||
|
|
||||||
@mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
|
|
||||||
def test_no_host_key_check_enabled(self, get_connection):
|
|
||||||
connection = Connection(
|
|
||||||
login='login', host='host',
|
|
||||||
extra='{"no_host_key_check": true}')
|
|
||||||
|
|
||||||
get_connection.return_value = connection
|
|
||||||
hook = SFTPHook()
|
|
||||||
self.assertEqual(hook.no_host_key_check, True)
|
|
||||||
|
|
||||||
@mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
|
|
||||||
def test_no_host_key_check_disabled(self, get_connection):
|
|
||||||
connection = Connection(
|
|
||||||
login='login', host='host',
|
|
||||||
extra='{"no_host_key_check": false}')
|
|
||||||
|
|
||||||
get_connection.return_value = connection
|
|
||||||
hook = SFTPHook()
|
|
||||||
self.assertEqual(hook.no_host_key_check, False)
|
|
||||||
|
|
||||||
@mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
|
|
||||||
def test_no_host_key_check_disabled_for_all_but_true(self, get_connection):
|
|
||||||
connection = Connection(
|
|
||||||
login='login', host='host',
|
|
||||||
extra='{"no_host_key_check": "foo"}')
|
|
||||||
|
|
||||||
get_connection.return_value = connection
|
|
||||||
hook = SFTPHook()
|
|
||||||
self.assertEqual(hook.no_host_key_check, False)
|
|
||||||
|
|
||||||
@mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
|
|
||||||
def test_no_host_key_check_ignore(self, get_connection):
|
|
||||||
connection = Connection(
|
|
||||||
login='login', host='host',
|
|
||||||
extra='{"ignore_hostkey_verification": true}')
|
|
||||||
|
|
||||||
get_connection.return_value = connection
|
|
||||||
hook = SFTPHook()
|
|
||||||
self.assertEqual(hook.no_host_key_check, True)
|
|
||||||
|
|
||||||
@mock.patch('airflow.contrib.hooks.sftp_hook.SFTPHook.get_connection')
|
|
||||||
def test_no_host_key_check_no_ignore(self, get_connection):
|
|
||||||
connection = Connection(
|
|
||||||
login='login', host='host',
|
|
||||||
extra='{"ignore_hostkey_verification": false}')
|
|
||||||
|
|
||||||
get_connection.return_value = connection
|
|
||||||
hook = SFTPHook()
|
|
||||||
self.assertEqual(hook.no_host_key_check, False)
|
|
||||||
|
|
||||||
@parameterized.expand([
|
|
||||||
(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True),
|
|
||||||
(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True),
|
|
||||||
(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False),
|
|
||||||
(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False),
|
|
||||||
])
|
|
||||||
def test_path_exists(self, path, exists):
|
|
||||||
result = self.hook.path_exists(path)
|
|
||||||
self.assertEqual(result, exists)
|
|
||||||
|
|
||||||
@parameterized.expand([
|
|
||||||
("test/path/file.bin", None, None, True),
|
|
||||||
("test/path/file.bin", "test", None, True),
|
|
||||||
("test/path/file.bin", "test/", None, True),
|
|
||||||
("test/path/file.bin", None, "bin", True),
|
|
||||||
("test/path/file.bin", "test", "bin", True),
|
|
||||||
("test/path/file.bin", "test/", "file.bin", True),
|
|
||||||
("test/path/file.bin", None, "file.bin", True),
|
|
||||||
("test/path/file.bin", "diff", None, False),
|
|
||||||
("test/path/file.bin", "test//", None, False),
|
|
||||||
("test/path/file.bin", None, ".txt", False),
|
|
||||||
("test/path/file.bin", "diff", ".txt", False),
|
|
||||||
])
|
|
||||||
def test_path_match(self, path, prefix, delimiter, match):
|
|
||||||
result = self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter)
|
|
||||||
self.assertEqual(result, match)
|
|
||||||
|
|
||||||
def test_get_tree_map(self):
|
|
||||||
tree_map = self.hook.get_tree_map(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
|
||||||
files, dirs, unknowns = tree_map
|
|
||||||
|
|
||||||
self.assertEqual(files, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS)])
|
|
||||||
self.assertEqual(dirs, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)])
|
|
||||||
self.assertEqual(unknowns, [])
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
|
||||||
os.remove(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS))
|
|
||||||
self.update_connection(self.old_login)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
||||||
|
|
|
@ -17,436 +17,17 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import os
|
import importlib
|
||||||
import unittest
|
from unittest import TestCase
|
||||||
from base64 import b64encode
|
|
||||||
from unittest import mock
|
|
||||||
|
|
||||||
from airflow import AirflowException
|
OLD_PATH = "airflow.contrib.operators.sftp_operator"
|
||||||
from airflow.contrib.operators.sftp_operator import SFTPOperation, SFTPOperator
|
NEW_PATH = "airflow.providers.sftp.operators.sftp_operator"
|
||||||
from airflow.contrib.operators.ssh_operator import SSHOperator
|
WARNING_MESSAGE = "This module is deprecated. Please use `{}`.".format(NEW_PATH)
|
||||||
from airflow.models import DAG, TaskInstance
|
|
||||||
from airflow.utils import timezone
|
|
||||||
from airflow.utils.timezone import datetime
|
|
||||||
from tests.test_utils.config import conf_vars
|
|
||||||
|
|
||||||
TEST_DAG_ID = 'unit_tests_sftp_op'
|
|
||||||
DEFAULT_DATE = datetime(2017, 1, 1)
|
|
||||||
TEST_CONN_ID = "conn_id_for_testing"
|
|
||||||
|
|
||||||
|
|
||||||
class TestSFTPOperator(unittest.TestCase):
|
class TestMovingSFTPHookToCore(TestCase):
|
||||||
def setUp(self):
|
def test_move_sftp_operator_to_core(self):
|
||||||
from airflow.contrib.hooks.ssh_hook import SSHHook
|
with self.assertWarns(DeprecationWarning) as warn:
|
||||||
|
# Reload to see deprecation warning each time
|
||||||
hook = SSHHook(ssh_conn_id='ssh_default')
|
importlib.import_module(OLD_PATH)
|
||||||
hook.no_host_key_check = True
|
self.assertEqual(WARNING_MESSAGE, str(warn.warning))
|
||||||
args = {
|
|
||||||
'owner': 'airflow',
|
|
||||||
'start_date': DEFAULT_DATE,
|
|
||||||
}
|
|
||||||
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
|
|
||||||
dag.schedule_interval = '@once'
|
|
||||||
self.hook = hook
|
|
||||||
self.dag = dag
|
|
||||||
self.test_dir = "/tmp"
|
|
||||||
self.test_local_dir = "/tmp/tmp2"
|
|
||||||
self.test_remote_dir = "/tmp/tmp1"
|
|
||||||
self.test_local_filename = 'test_local_file'
|
|
||||||
self.test_remote_filename = 'test_remote_file'
|
|
||||||
self.test_local_filepath = '{0}/{1}'.format(self.test_dir,
|
|
||||||
self.test_local_filename)
|
|
||||||
# Local Filepath with Intermediate Directory
|
|
||||||
self.test_local_filepath_int_dir = '{0}/{1}'.format(self.test_local_dir,
|
|
||||||
self.test_local_filename)
|
|
||||||
self.test_remote_filepath = '{0}/{1}'.format(self.test_dir,
|
|
||||||
self.test_remote_filename)
|
|
||||||
# Remote Filepath with Intermediate Directory
|
|
||||||
self.test_remote_filepath_int_dir = '{0}/{1}'.format(self.test_remote_dir,
|
|
||||||
self.test_remote_filename)
|
|
||||||
|
|
||||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
|
||||||
def test_pickle_file_transfer_put(self):
|
|
||||||
test_local_file_content = \
|
|
||||||
b"This is local file content \n which is multiline " \
|
|
||||||
b"continuing....with other character\nanother line here \n this is last line"
|
|
||||||
# create a test file locally
|
|
||||||
with open(self.test_local_filepath, 'wb') as file:
|
|
||||||
file.write(test_local_file_content)
|
|
||||||
|
|
||||||
# put test file to remote
|
|
||||||
put_test_task = SFTPOperator(
|
|
||||||
task_id="test_sftp",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
local_filepath=self.test_local_filepath,
|
|
||||||
remote_filepath=self.test_remote_filepath,
|
|
||||||
operation=SFTPOperation.PUT,
|
|
||||||
create_intermediate_dirs=True,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(put_test_task)
|
|
||||||
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
|
|
||||||
ti2.run()
|
|
||||||
|
|
||||||
# check the remote file content
|
|
||||||
check_file_task = SSHOperator(
|
|
||||||
task_id="test_check_file",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
command="cat {0}".format(self.test_remote_filepath),
|
|
||||||
do_xcom_push=True,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(check_file_task)
|
|
||||||
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
|
|
||||||
ti3.run()
|
|
||||||
self.assertEqual(
|
|
||||||
ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
|
|
||||||
test_local_file_content)
|
|
||||||
|
|
||||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
|
||||||
def test_file_transfer_no_intermediate_dir_error_put(self):
|
|
||||||
test_local_file_content = \
|
|
||||||
b"This is local file content \n which is multiline " \
|
|
||||||
b"continuing....with other character\nanother line here \n this is last line"
|
|
||||||
# create a test file locally
|
|
||||||
with open(self.test_local_filepath, 'wb') as file:
|
|
||||||
file.write(test_local_file_content)
|
|
||||||
|
|
||||||
# Try to put test file to remote
|
|
||||||
# This should raise an error with "No such file" as the directory
|
|
||||||
# does not exist
|
|
||||||
with self.assertRaises(Exception) as error:
|
|
||||||
put_test_task = SFTPOperator(
|
|
||||||
task_id="test_sftp",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
local_filepath=self.test_local_filepath,
|
|
||||||
remote_filepath=self.test_remote_filepath_int_dir,
|
|
||||||
operation=SFTPOperation.PUT,
|
|
||||||
create_intermediate_dirs=False,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(put_test_task)
|
|
||||||
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
|
|
||||||
ti2.run()
|
|
||||||
self.assertIn('No such file', str(error.exception))
|
|
||||||
|
|
||||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
|
||||||
def test_file_transfer_with_intermediate_dir_put(self):
|
|
||||||
test_local_file_content = \
|
|
||||||
b"This is local file content \n which is multiline " \
|
|
||||||
b"continuing....with other character\nanother line here \n this is last line"
|
|
||||||
# create a test file locally
|
|
||||||
with open(self.test_local_filepath, 'wb') as file:
|
|
||||||
file.write(test_local_file_content)
|
|
||||||
|
|
||||||
# put test file to remote
|
|
||||||
put_test_task = SFTPOperator(
|
|
||||||
task_id="test_sftp",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
local_filepath=self.test_local_filepath,
|
|
||||||
remote_filepath=self.test_remote_filepath_int_dir,
|
|
||||||
operation=SFTPOperation.PUT,
|
|
||||||
create_intermediate_dirs=True,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(put_test_task)
|
|
||||||
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
|
|
||||||
ti2.run()
|
|
||||||
|
|
||||||
# check the remote file content
|
|
||||||
check_file_task = SSHOperator(
|
|
||||||
task_id="test_check_file",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
command="cat {0}".format(self.test_remote_filepath_int_dir),
|
|
||||||
do_xcom_push=True,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(check_file_task)
|
|
||||||
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
|
|
||||||
ti3.run()
|
|
||||||
self.assertEqual(
|
|
||||||
ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
|
|
||||||
test_local_file_content)
|
|
||||||
|
|
||||||
@conf_vars({('core', 'enable_xcom_pickling'): 'False'})
|
|
||||||
def test_json_file_transfer_put(self):
|
|
||||||
test_local_file_content = \
|
|
||||||
b"This is local file content \n which is multiline " \
|
|
||||||
b"continuing....with other character\nanother line here \n this is last line"
|
|
||||||
# create a test file locally
|
|
||||||
with open(self.test_local_filepath, 'wb') as file:
|
|
||||||
file.write(test_local_file_content)
|
|
||||||
|
|
||||||
# put test file to remote
|
|
||||||
put_test_task = SFTPOperator(
|
|
||||||
task_id="test_sftp",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
local_filepath=self.test_local_filepath,
|
|
||||||
remote_filepath=self.test_remote_filepath,
|
|
||||||
operation=SFTPOperation.PUT,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(put_test_task)
|
|
||||||
ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
|
|
||||||
ti2.run()
|
|
||||||
|
|
||||||
# check the remote file content
|
|
||||||
check_file_task = SSHOperator(
|
|
||||||
task_id="test_check_file",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
command="cat {0}".format(self.test_remote_filepath),
|
|
||||||
do_xcom_push=True,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(check_file_task)
|
|
||||||
ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
|
|
||||||
ti3.run()
|
|
||||||
self.assertEqual(
|
|
||||||
ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
|
|
||||||
b64encode(test_local_file_content).decode('utf-8'))
|
|
||||||
|
|
||||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
|
||||||
def test_pickle_file_transfer_get(self):
|
|
||||||
test_remote_file_content = \
|
|
||||||
"This is remote file content \n which is also multiline " \
|
|
||||||
"another line here \n this is last line. EOF"
|
|
||||||
|
|
||||||
# create a test file remotely
|
|
||||||
create_file_task = SSHOperator(
|
|
||||||
task_id="test_create_file",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
command="echo '{0}' > {1}".format(test_remote_file_content,
|
|
||||||
self.test_remote_filepath),
|
|
||||||
do_xcom_push=True,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(create_file_task)
|
|
||||||
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
|
|
||||||
ti1.run()
|
|
||||||
|
|
||||||
# get remote file to local
|
|
||||||
get_test_task = SFTPOperator(
|
|
||||||
task_id="test_sftp",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
local_filepath=self.test_local_filepath,
|
|
||||||
remote_filepath=self.test_remote_filepath,
|
|
||||||
operation=SFTPOperation.GET,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(get_test_task)
|
|
||||||
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
|
|
||||||
ti2.run()
|
|
||||||
|
|
||||||
# test the received content
|
|
||||||
content_received = None
|
|
||||||
with open(self.test_local_filepath, 'r') as file:
|
|
||||||
content_received = file.read()
|
|
||||||
self.assertEqual(content_received.strip(), test_remote_file_content)
|
|
||||||
|
|
||||||
@conf_vars({('core', 'enable_xcom_pickling'): 'False'})
|
|
||||||
def test_json_file_transfer_get(self):
|
|
||||||
test_remote_file_content = \
|
|
||||||
"This is remote file content \n which is also multiline " \
|
|
||||||
"another line here \n this is last line. EOF"
|
|
||||||
|
|
||||||
# create a test file remotely
|
|
||||||
create_file_task = SSHOperator(
|
|
||||||
task_id="test_create_file",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
command="echo '{0}' > {1}".format(test_remote_file_content,
|
|
||||||
self.test_remote_filepath),
|
|
||||||
do_xcom_push=True,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(create_file_task)
|
|
||||||
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
|
|
||||||
ti1.run()
|
|
||||||
|
|
||||||
# get remote file to local
|
|
||||||
get_test_task = SFTPOperator(
|
|
||||||
task_id="test_sftp",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
local_filepath=self.test_local_filepath,
|
|
||||||
remote_filepath=self.test_remote_filepath,
|
|
||||||
operation=SFTPOperation.GET,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(get_test_task)
|
|
||||||
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
|
|
||||||
ti2.run()
|
|
||||||
|
|
||||||
# test the received content
|
|
||||||
content_received = None
|
|
||||||
with open(self.test_local_filepath, 'r') as file:
|
|
||||||
content_received = file.read()
|
|
||||||
self.assertEqual(content_received.strip(),
|
|
||||||
test_remote_file_content.encode('utf-8').decode('utf-8'))
|
|
||||||
|
|
||||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
|
||||||
def test_file_transfer_no_intermediate_dir_error_get(self):
|
|
||||||
test_remote_file_content = \
|
|
||||||
"This is remote file content \n which is also multiline " \
|
|
||||||
"another line here \n this is last line. EOF"
|
|
||||||
|
|
||||||
# create a test file remotely
|
|
||||||
create_file_task = SSHOperator(
|
|
||||||
task_id="test_create_file",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
command="echo '{0}' > {1}".format(test_remote_file_content,
|
|
||||||
self.test_remote_filepath),
|
|
||||||
do_xcom_push=True,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(create_file_task)
|
|
||||||
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
|
|
||||||
ti1.run()
|
|
||||||
|
|
||||||
# Try to GET test file from remote
|
|
||||||
# This should raise an error with "No such file" as the directory
|
|
||||||
# does not exist
|
|
||||||
with self.assertRaises(Exception) as error:
|
|
||||||
get_test_task = SFTPOperator(
|
|
||||||
task_id="test_sftp",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
local_filepath=self.test_local_filepath_int_dir,
|
|
||||||
remote_filepath=self.test_remote_filepath,
|
|
||||||
operation=SFTPOperation.GET,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(get_test_task)
|
|
||||||
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
|
|
||||||
ti2.run()
|
|
||||||
self.assertIn('No such file', str(error.exception))
|
|
||||||
|
|
||||||
@conf_vars({('core', 'enable_xcom_pickling'): 'True'})
|
|
||||||
def test_file_transfer_with_intermediate_dir_error_get(self):
|
|
||||||
test_remote_file_content = \
|
|
||||||
"This is remote file content \n which is also multiline " \
|
|
||||||
"another line here \n this is last line. EOF"
|
|
||||||
|
|
||||||
# create a test file remotely
|
|
||||||
create_file_task = SSHOperator(
|
|
||||||
task_id="test_create_file",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
command="echo '{0}' > {1}".format(test_remote_file_content,
|
|
||||||
self.test_remote_filepath),
|
|
||||||
do_xcom_push=True,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(create_file_task)
|
|
||||||
ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
|
|
||||||
ti1.run()
|
|
||||||
|
|
||||||
# get remote file to local
|
|
||||||
get_test_task = SFTPOperator(
|
|
||||||
task_id="test_sftp",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
local_filepath=self.test_local_filepath_int_dir,
|
|
||||||
remote_filepath=self.test_remote_filepath,
|
|
||||||
operation=SFTPOperation.GET,
|
|
||||||
create_intermediate_dirs=True,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(get_test_task)
|
|
||||||
ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
|
|
||||||
ti2.run()
|
|
||||||
|
|
||||||
# test the received content
|
|
||||||
content_received = None
|
|
||||||
with open(self.test_local_filepath_int_dir, 'r') as file:
|
|
||||||
content_received = file.read()
|
|
||||||
self.assertEqual(content_received.strip(), test_remote_file_content)
|
|
||||||
|
|
||||||
@mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"})
|
|
||||||
def test_arg_checking(self):
|
|
||||||
# Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
|
|
||||||
with self.assertRaisesRegex(AirflowException,
|
|
||||||
"Cannot operate without ssh_hook or ssh_conn_id."):
|
|
||||||
task_0 = SFTPOperator(
|
|
||||||
task_id="test_sftp",
|
|
||||||
local_filepath=self.test_local_filepath,
|
|
||||||
remote_filepath=self.test_remote_filepath,
|
|
||||||
operation=SFTPOperation.PUT,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
task_0.execute(None)
|
|
||||||
|
|
||||||
# if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook
|
|
||||||
task_1 = SFTPOperator(
|
|
||||||
task_id="test_sftp",
|
|
||||||
ssh_hook="string_rather_than_SSHHook", # invalid ssh_hook
|
|
||||||
ssh_conn_id=TEST_CONN_ID,
|
|
||||||
local_filepath=self.test_local_filepath,
|
|
||||||
remote_filepath=self.test_remote_filepath,
|
|
||||||
operation=SFTPOperation.PUT,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
task_1.execute(None)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID)
|
|
||||||
|
|
||||||
task_2 = SFTPOperator(
|
|
||||||
task_id="test_sftp",
|
|
||||||
ssh_conn_id=TEST_CONN_ID, # no ssh_hook provided
|
|
||||||
local_filepath=self.test_local_filepath,
|
|
||||||
remote_filepath=self.test_remote_filepath,
|
|
||||||
operation=SFTPOperation.PUT,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
task_2.execute(None)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID)
|
|
||||||
|
|
||||||
# if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
|
|
||||||
task_3 = SFTPOperator(
|
|
||||||
task_id="test_sftp",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
ssh_conn_id=TEST_CONN_ID,
|
|
||||||
local_filepath=self.test_local_filepath,
|
|
||||||
remote_filepath=self.test_remote_filepath,
|
|
||||||
operation=SFTPOperation.PUT,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
task_3.execute(None)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id)
|
|
||||||
|
|
||||||
def delete_local_resource(self):
|
|
||||||
if os.path.exists(self.test_local_filepath):
|
|
||||||
os.remove(self.test_local_filepath)
|
|
||||||
if os.path.exists(self.test_local_filepath_int_dir):
|
|
||||||
os.remove(self.test_local_filepath_int_dir)
|
|
||||||
if os.path.exists(self.test_local_dir):
|
|
||||||
os.rmdir(self.test_local_dir)
|
|
||||||
|
|
||||||
def delete_remote_resource(self):
|
|
||||||
if os.path.exists(self.test_remote_filepath):
|
|
||||||
# check the remote file content
|
|
||||||
remove_file_task = SSHOperator(
|
|
||||||
task_id="test_check_file",
|
|
||||||
ssh_hook=self.hook,
|
|
||||||
command="rm {0}".format(self.test_remote_filepath),
|
|
||||||
do_xcom_push=True,
|
|
||||||
dag=self.dag
|
|
||||||
)
|
|
||||||
self.assertIsNotNone(remove_file_task)
|
|
||||||
ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow())
|
|
||||||
ti3.run()
|
|
||||||
if os.path.exists(self.test_remote_filepath_int_dir):
|
|
||||||
os.remove(self.test_remote_filepath_int_dir)
|
|
||||||
if os.path.exists(self.test_remote_dir):
|
|
||||||
os.rmdir(self.test_remote_dir)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
self.delete_local_resource()
|
|
||||||
self.delete_remote_resource()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
||||||
|
|
|
@ -17,61 +17,17 @@
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
import unittest
|
import importlib
|
||||||
from unittest.mock import patch
|
from unittest import TestCase
|
||||||
|
|
||||||
from paramiko import SFTP_FAILURE, SFTP_NO_SUCH_FILE
|
OLD_PATH = "airflow.contrib.sensors.sftp_sensor"
|
||||||
|
NEW_PATH = "airflow.providers.sftp.sensors.sftp_sensor"
|
||||||
from airflow.contrib.sensors.sftp_sensor import SFTPSensor
|
WARNING_MESSAGE = "This module is deprecated. Please use `{}`.".format(NEW_PATH)
|
||||||
|
|
||||||
|
|
||||||
class TestSFTPSensor(unittest.TestCase):
|
class TestMovingSFTPHookToCore(TestCase):
|
||||||
@patch('airflow.contrib.sensors.sftp_sensor.SFTPHook')
|
def test_move_sftp_sensor_to_core(self):
|
||||||
def test_file_present(self, sftp_hook_mock):
|
with self.assertWarns(DeprecationWarning) as warn:
|
||||||
sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
|
# Reload to see deprecation warning each time
|
||||||
sftp_sensor = SFTPSensor(
|
importlib.import_module(OLD_PATH)
|
||||||
task_id='unit_test',
|
self.assertEqual(WARNING_MESSAGE, str(warn.warning))
|
||||||
path='/path/to/file/1970-01-01.txt')
|
|
||||||
context = {
|
|
||||||
'ds': '1970-01-01'
|
|
||||||
}
|
|
||||||
output = sftp_sensor.poke(context)
|
|
||||||
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
|
|
||||||
'/path/to/file/1970-01-01.txt')
|
|
||||||
self.assertTrue(output)
|
|
||||||
|
|
||||||
@patch('airflow.contrib.sensors.sftp_sensor.SFTPHook')
|
|
||||||
def test_file_absent(self, sftp_hook_mock):
|
|
||||||
sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(
|
|
||||||
SFTP_NO_SUCH_FILE, 'File missing')
|
|
||||||
sftp_sensor = SFTPSensor(
|
|
||||||
task_id='unit_test',
|
|
||||||
path='/path/to/file/1970-01-01.txt')
|
|
||||||
context = {
|
|
||||||
'ds': '1970-01-01'
|
|
||||||
}
|
|
||||||
output = sftp_sensor.poke(context)
|
|
||||||
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
|
|
||||||
'/path/to/file/1970-01-01.txt')
|
|
||||||
self.assertFalse(output)
|
|
||||||
|
|
||||||
@patch('airflow.contrib.sensors.sftp_sensor.SFTPHook')
|
|
||||||
def test_sftp_failure(self, sftp_hook_mock):
|
|
||||||
sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(
|
|
||||||
SFTP_FAILURE, 'SFTP failure')
|
|
||||||
sftp_sensor = SFTPSensor(
|
|
||||||
task_id='unit_test',
|
|
||||||
path='/path/to/file/1970-01-01.txt')
|
|
||||||
context = {
|
|
||||||
'ds': '1970-01-01'
|
|
||||||
}
|
|
||||||
with self.assertRaises(OSError):
|
|
||||||
sftp_sensor.poke(context)
|
|
||||||
sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
|
|
||||||
'/path/to/file/1970-01-01.txt')
|
|
||||||
|
|
||||||
def test_hook_not_created_during_init(self):
|
|
||||||
sftp_sensor = SFTPSensor(
|
|
||||||
task_id='unit_test',
|
|
||||||
path='/path/to/file/1970-01-01.txt')
|
|
||||||
self.assertIsNone(sftp_sensor.hook)
|
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,247 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import unittest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pysftp
|
||||||
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
from airflow.models import Connection
|
||||||
|
from airflow.providers.sftp.hooks.sftp_hook import SFTPHook
|
||||||
|
from airflow.utils.db import provide_session
|
||||||
|
|
||||||
|
TMP_PATH = '/tmp'
|
||||||
|
TMP_DIR_FOR_TESTS = 'tests_sftp_hook_dir'
|
||||||
|
SUB_DIR = "sub_dir"
|
||||||
|
TMP_FILE_FOR_TESTS = 'test_file.txt'
|
||||||
|
|
||||||
|
SFTP_CONNECTION_USER = "root"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSFTPHook(unittest.TestCase):
|
||||||
|
|
||||||
|
@provide_session
|
||||||
|
def update_connection(self, login, session=None):
|
||||||
|
connection = (session.query(Connection).
|
||||||
|
filter(Connection.conn_id == "sftp_default")
|
||||||
|
.first())
|
||||||
|
old_login = connection.login
|
||||||
|
connection.login = login
|
||||||
|
session.commit()
|
||||||
|
return old_login
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.old_login = self.update_connection(SFTP_CONNECTION_USER)
|
||||||
|
self.hook = SFTPHook()
|
||||||
|
os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR))
|
||||||
|
|
||||||
|
with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as file:
|
||||||
|
file.write('Test file')
|
||||||
|
with open(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS), 'a') as file:
|
||||||
|
file.write('Test file')
|
||||||
|
|
||||||
|
def test_get_conn(self):
|
||||||
|
output = self.hook.get_conn()
|
||||||
|
self.assertEqual(type(output), pysftp.Connection)
|
||||||
|
|
||||||
|
def test_close_conn(self):
|
||||||
|
self.hook.conn = self.hook.get_conn()
|
||||||
|
self.assertTrue(self.hook.conn is not None)
|
||||||
|
self.hook.close_conn()
|
||||||
|
self.assertTrue(self.hook.conn is None)
|
||||||
|
|
||||||
|
def test_describe_directory(self):
|
||||||
|
output = self.hook.describe_directory(TMP_PATH)
|
||||||
|
self.assertTrue(TMP_DIR_FOR_TESTS in output)
|
||||||
|
|
||||||
|
def test_list_directory(self):
|
||||||
|
output = self.hook.list_directory(
|
||||||
|
path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||||
|
self.assertEqual(output, [SUB_DIR])
|
||||||
|
|
||||||
|
def test_create_and_delete_directory(self):
|
||||||
|
new_dir_name = 'new_dir'
|
||||||
|
self.hook.create_directory(os.path.join(
|
||||||
|
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
|
||||||
|
output = self.hook.describe_directory(
|
||||||
|
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||||
|
self.assertTrue(new_dir_name in output)
|
||||||
|
self.hook.delete_directory(os.path.join(
|
||||||
|
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_name))
|
||||||
|
output = self.hook.describe_directory(
|
||||||
|
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||||
|
self.assertTrue(new_dir_name not in output)
|
||||||
|
|
||||||
|
def test_create_and_delete_directories(self):
|
||||||
|
base_dir = "base_dir"
|
||||||
|
sub_dir = "sub_dir"
|
||||||
|
new_dir_path = os.path.join(base_dir, sub_dir)
|
||||||
|
self.hook.create_directory(os.path.join(
|
||||||
|
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
|
||||||
|
output = self.hook.describe_directory(
|
||||||
|
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||||
|
self.assertTrue(base_dir in output)
|
||||||
|
output = self.hook.describe_directory(
|
||||||
|
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
|
||||||
|
self.assertTrue(sub_dir in output)
|
||||||
|
self.hook.delete_directory(os.path.join(
|
||||||
|
TMP_PATH, TMP_DIR_FOR_TESTS, new_dir_path))
|
||||||
|
self.hook.delete_directory(os.path.join(
|
||||||
|
TMP_PATH, TMP_DIR_FOR_TESTS, base_dir))
|
||||||
|
output = self.hook.describe_directory(
|
||||||
|
os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||||
|
self.assertTrue(new_dir_path not in output)
|
||||||
|
self.assertTrue(base_dir not in output)
|
||||||
|
|
||||||
|
def test_store_retrieve_and_delete_file(self):
|
||||||
|
self.hook.store_file(
|
||||||
|
remote_full_path=os.path.join(
|
||||||
|
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
|
||||||
|
local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
|
||||||
|
)
|
||||||
|
output = self.hook.list_directory(
|
||||||
|
path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||||
|
self.assertEqual(output, [SUB_DIR, TMP_FILE_FOR_TESTS])
|
||||||
|
retrieved_file_name = 'retrieved.txt'
|
||||||
|
self.hook.retrieve_file(
|
||||||
|
remote_full_path=os.path.join(
|
||||||
|
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
|
||||||
|
local_full_path=os.path.join(TMP_PATH, retrieved_file_name)
|
||||||
|
)
|
||||||
|
self.assertTrue(retrieved_file_name in os.listdir(TMP_PATH))
|
||||||
|
os.remove(os.path.join(TMP_PATH, retrieved_file_name))
|
||||||
|
self.hook.delete_file(path=os.path.join(
|
||||||
|
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
|
||||||
|
output = self.hook.list_directory(
|
||||||
|
path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
|
||||||
|
self.assertEqual(output, [SUB_DIR])
|
||||||
|
|
||||||
|
def test_get_mod_time(self):
|
||||||
|
self.hook.store_file(
|
||||||
|
remote_full_path=os.path.join(
|
||||||
|
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS),
|
||||||
|
local_full_path=os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
|
||||||
|
)
|
||||||
|
output = self.hook.get_mod_time(path=os.path.join(
|
||||||
|
TMP_PATH, TMP_DIR_FOR_TESTS, TMP_FILE_FOR_TESTS))
|
||||||
|
self.assertEqual(len(output), 14)
|
||||||
|
|
||||||
|
@mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
|
||||||
|
def test_no_host_key_check_default(self, get_connection):
|
||||||
|
connection = Connection(login='login', host='host')
|
||||||
|
get_connection.return_value = connection
|
||||||
|
hook = SFTPHook()
|
||||||
|
self.assertEqual(hook.no_host_key_check, False)
|
||||||
|
|
||||||
|
@mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
|
||||||
|
def test_no_host_key_check_enabled(self, get_connection):
|
||||||
|
connection = Connection(
|
||||||
|
login='login', host='host',
|
||||||
|
extra='{"no_host_key_check": true}')
|
||||||
|
|
||||||
|
get_connection.return_value = connection
|
||||||
|
hook = SFTPHook()
|
||||||
|
self.assertEqual(hook.no_host_key_check, True)
|
||||||
|
|
||||||
|
@mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
|
||||||
|
def test_no_host_key_check_disabled(self, get_connection):
|
||||||
|
connection = Connection(
|
||||||
|
login='login', host='host',
|
||||||
|
extra='{"no_host_key_check": false}')
|
||||||
|
|
||||||
|
get_connection.return_value = connection
|
||||||
|
hook = SFTPHook()
|
||||||
|
self.assertEqual(hook.no_host_key_check, False)
|
||||||
|
|
||||||
|
@mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
def test_no_host_key_check_disabled_for_all_but_true(self, get_connection):
    """Any extra value other than boolean true keeps host-key checking on."""
    conn = Connection(
        login='login', host='host',
        extra='{"no_host_key_check": "foo"}')
    get_connection.return_value = conn
    self.assertEqual(SFTPHook().no_host_key_check, False)
@mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
def test_no_host_key_check_ignore(self, get_connection):
    """Legacy extra ignore_hostkey_verification=true also disables checking."""
    conn = Connection(
        login='login', host='host',
        extra='{"ignore_hostkey_verification": true}')
    get_connection.return_value = conn
    self.assertEqual(SFTPHook().no_host_key_check, True)
@mock.patch('airflow.providers.sftp.hooks.sftp_hook.SFTPHook.get_connection')
def test_no_host_key_check_no_ignore(self, get_connection):
    """Legacy extra ignore_hostkey_verification=false keeps checking enabled."""
    conn = Connection(
        login='login', host='host',
        extra='{"ignore_hostkey_verification": false}')
    get_connection.return_value = conn
    self.assertEqual(SFTPHook().no_host_key_check, False)
@parameterized.expand([
    # (path, expected_exists): the first two paths are created in setUp,
    # the last two are deliberately nonexistent variants.
    (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True),
    (os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True),
    (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False),
    (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False),
])
def test_path_exists(self, path, exists):
    """path_exists mirrors the actual on-server existence of each path."""
    result = self.hook.path_exists(path)
    self.assertEqual(result, exists)
@parameterized.expand([
    # (path, prefix, delimiter, expected_match) — prefix must be a string
    # prefix of the path, delimiter a suffix; None means "don't filter".
    ("test/path/file.bin", None, None, True),
    ("test/path/file.bin", "test", None, True),
    ("test/path/file.bin", "test/", None, True),
    ("test/path/file.bin", None, "bin", True),
    ("test/path/file.bin", "test", "bin", True),
    ("test/path/file.bin", "test/", "file.bin", True),
    ("test/path/file.bin", None, "file.bin", True),
    ("test/path/file.bin", "diff", None, False),
    ("test/path/file.bin", "test//", None, False),
    ("test/path/file.bin", None, ".txt", False),
    ("test/path/file.bin", "diff", ".txt", False),
])
def test_path_match(self, path, prefix, delimiter, match):
    """_is_path_match applies prefix/delimiter filtering as tabulated above."""
    result = self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter)
    self.assertEqual(result, match)
def test_get_tree_map(self):
    """get_tree_map partitions the remote tree into (files, dirs, unknowns).

    Relies on the fixture layout created in setUp:
    TMP_DIR_FOR_TESTS/SUB_DIR/TMP_FILE_FOR_TESTS.
    """
    tree_map = self.hook.get_tree_map(path=os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
    files, dirs, unknowns = tree_map

    self.assertEqual(files, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR, TMP_FILE_FOR_TESTS)])
    self.assertEqual(dirs, [os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)])
    # Nothing in the fixture should be classified as unknown.
    self.assertEqual(unknowns, [])
def tearDown(self):
    """Delete the fixture tree and file, then restore the original login."""
    fixture_dir = os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)
    fixture_file = os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS)
    shutil.rmtree(fixture_dir)
    os.remove(fixture_file)
    self.update_connection(self.old_login)
# Allow running this test module directly (outside a test runner).
if __name__ == '__main__':
    unittest.main()
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,452 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from base64 import b64encode
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
from airflow import AirflowException
|
||||||
|
from airflow.contrib.operators.ssh_operator import SSHOperator
|
||||||
|
from airflow.models import DAG, TaskInstance
|
||||||
|
from airflow.providers.sftp.operators.sftp_operator import SFTPOperation, SFTPOperator
|
||||||
|
from airflow.utils import timezone
|
||||||
|
from airflow.utils.timezone import datetime
|
||||||
|
from tests.test_utils.config import conf_vars
|
||||||
|
|
||||||
|
# Prefix for DAG ids created by these tests.
TEST_DAG_ID = 'unit_tests_sftp_op'
# Fixed start date for the test DAG.
DEFAULT_DATE = datetime(2017, 1, 1)
# Connection id injected via the AIRFLOW_CONN_* environment variable
# in test_arg_checking.
TEST_CONN_ID = "conn_id_for_testing"
class TestSFTPOperator(unittest.TestCase):
    """Integration tests for SFTPOperator PUT/GET against a live SSH endpoint.

    Each transfer test writes a file on one side, runs an SFTPOperator task
    via a TaskInstance, then verifies the other side with an SSHOperator
    (``cat``) whose output is pulled back through XCom. Requires a working
    ``ssh_default`` connection — these are not pure unit tests.
    """

    def setUp(self):
        # Imported lazily so the module can be collected without the
        # contrib SSH hook being importable.
        from airflow.contrib.hooks.ssh_hook import SSHHook

        hook = SSHHook(ssh_conn_id='ssh_default')
        hook.no_host_key_check = True
        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
        }
        dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
        dag.schedule_interval = '@once'
        self.hook = hook
        self.dag = dag
        self.test_dir = "/tmp"
        self.test_local_dir = "/tmp/tmp2"
        self.test_remote_dir = "/tmp/tmp1"
        self.test_local_filename = 'test_local_file'
        self.test_remote_filename = 'test_remote_file'
        self.test_local_filepath = '{0}/{1}'.format(self.test_dir,
                                                    self.test_local_filename)
        # Local Filepath with Intermediate Directory
        self.test_local_filepath_int_dir = '{0}/{1}'.format(self.test_local_dir,
                                                            self.test_local_filename)
        self.test_remote_filepath = '{0}/{1}'.format(self.test_dir,
                                                     self.test_remote_filename)
        # Remote Filepath with Intermediate Directory
        self.test_remote_filepath_int_dir = '{0}/{1}'.format(self.test_remote_dir,
                                                             self.test_remote_filename)

    @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
    def test_pickle_file_transfer_put(self):
        """PUT a local file, then read it back remotely via SSH ``cat``."""
        test_local_file_content = \
            b"This is local file content \n which is multiline " \
            b"continuing....with other character\nanother line here \n this is last line"
        # create a test file locally
        with open(self.test_local_filepath, 'wb') as file:
            file.write(test_local_file_content)

        # put test file to remote
        put_test_task = SFTPOperator(
            task_id="test_sftp",
            ssh_hook=self.hook,
            local_filepath=self.test_local_filepath,
            remote_filepath=self.test_remote_filepath,
            operation=SFTPOperation.PUT,
            create_intermediate_dirs=True,
            dag=self.dag
        )
        self.assertIsNotNone(put_test_task)
        ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
        ti2.run()

        # check the remote file content
        check_file_task = SSHOperator(
            task_id="test_check_file",
            ssh_hook=self.hook,
            command="cat {0}".format(self.test_remote_filepath),
            do_xcom_push=True,
            dag=self.dag
        )
        self.assertIsNotNone(check_file_task)
        ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
        ti3.run()
        # With pickling enabled the XCom round-trips raw bytes.
        self.assertEqual(
            ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
            test_local_file_content)

    @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
    def test_file_transfer_no_intermediate_dir_error_put(self):
        """PUT into a missing remote directory must fail when dirs aren't created."""
        test_local_file_content = \
            b"This is local file content \n which is multiline " \
            b"continuing....with other character\nanother line here \n this is last line"
        # create a test file locally
        with open(self.test_local_filepath, 'wb') as file:
            file.write(test_local_file_content)

        # Try to put test file to remote
        # This should raise an error with "No such file" as the directory
        # does not exist
        with self.assertRaises(Exception) as error:
            put_test_task = SFTPOperator(
                task_id="test_sftp",
                ssh_hook=self.hook,
                local_filepath=self.test_local_filepath,
                remote_filepath=self.test_remote_filepath_int_dir,
                operation=SFTPOperation.PUT,
                create_intermediate_dirs=False,
                dag=self.dag
            )
            self.assertIsNotNone(put_test_task)
            ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
            ti2.run()
        self.assertIn('No such file', str(error.exception))

    @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
    def test_file_transfer_with_intermediate_dir_put(self):
        """PUT succeeds into a missing directory when create_intermediate_dirs=True."""
        test_local_file_content = \
            b"This is local file content \n which is multiline " \
            b"continuing....with other character\nanother line here \n this is last line"
        # create a test file locally
        with open(self.test_local_filepath, 'wb') as file:
            file.write(test_local_file_content)

        # put test file to remote
        put_test_task = SFTPOperator(
            task_id="test_sftp",
            ssh_hook=self.hook,
            local_filepath=self.test_local_filepath,
            remote_filepath=self.test_remote_filepath_int_dir,
            operation=SFTPOperation.PUT,
            create_intermediate_dirs=True,
            dag=self.dag
        )
        self.assertIsNotNone(put_test_task)
        ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
        ti2.run()

        # check the remote file content
        check_file_task = SSHOperator(
            task_id="test_check_file",
            ssh_hook=self.hook,
            command="cat {0}".format(self.test_remote_filepath_int_dir),
            do_xcom_push=True,
            dag=self.dag
        )
        self.assertIsNotNone(check_file_task)
        ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
        ti3.run()
        self.assertEqual(
            ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
            test_local_file_content)

    @conf_vars({('core', 'enable_xcom_pickling'): 'False'})
    def test_json_file_transfer_put(self):
        """With JSON XCom (no pickling), SSH output is base64-encoded."""
        test_local_file_content = \
            b"This is local file content \n which is multiline " \
            b"continuing....with other character\nanother line here \n this is last line"
        # create a test file locally
        with open(self.test_local_filepath, 'wb') as file:
            file.write(test_local_file_content)

        # put test file to remote
        put_test_task = SFTPOperator(
            task_id="test_sftp",
            ssh_hook=self.hook,
            local_filepath=self.test_local_filepath,
            remote_filepath=self.test_remote_filepath,
            operation=SFTPOperation.PUT,
            dag=self.dag
        )
        self.assertIsNotNone(put_test_task)
        ti2 = TaskInstance(task=put_test_task, execution_date=timezone.utcnow())
        ti2.run()

        # check the remote file content
        check_file_task = SSHOperator(
            task_id="test_check_file",
            ssh_hook=self.hook,
            command="cat {0}".format(self.test_remote_filepath),
            do_xcom_push=True,
            dag=self.dag
        )
        self.assertIsNotNone(check_file_task)
        ti3 = TaskInstance(task=check_file_task, execution_date=timezone.utcnow())
        ti3.run()
        # JSON XCom cannot carry raw bytes, so the SSH operator pushed
        # a base64 string; compare against the encoded content.
        self.assertEqual(
            ti3.xcom_pull(task_ids='test_check_file', key='return_value').strip(),
            b64encode(test_local_file_content).decode('utf-8'))

    @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
    def test_pickle_file_transfer_get(self):
        """Create a remote file via SSH, GET it, and compare local content."""
        test_remote_file_content = \
            "This is remote file content \n which is also multiline " \
            "another line here \n this is last line. EOF"

        # create a test file remotely
        create_file_task = SSHOperator(
            task_id="test_create_file",
            ssh_hook=self.hook,
            command="echo '{0}' > {1}".format(test_remote_file_content,
                                              self.test_remote_filepath),
            do_xcom_push=True,
            dag=self.dag
        )
        self.assertIsNotNone(create_file_task)
        ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
        ti1.run()

        # get remote file to local
        get_test_task = SFTPOperator(
            task_id="test_sftp",
            ssh_hook=self.hook,
            local_filepath=self.test_local_filepath,
            remote_filepath=self.test_remote_filepath,
            operation=SFTPOperation.GET,
            dag=self.dag
        )
        self.assertIsNotNone(get_test_task)
        ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
        ti2.run()

        # test the received content
        content_received = None
        with open(self.test_local_filepath, 'r') as file:
            content_received = file.read()
        self.assertEqual(content_received.strip(), test_remote_file_content)

    @conf_vars({('core', 'enable_xcom_pickling'): 'False'})
    def test_json_file_transfer_get(self):
        """GET works identically when XCom uses JSON serialization."""
        test_remote_file_content = \
            "This is remote file content \n which is also multiline " \
            "another line here \n this is last line. EOF"

        # create a test file remotely
        create_file_task = SSHOperator(
            task_id="test_create_file",
            ssh_hook=self.hook,
            command="echo '{0}' > {1}".format(test_remote_file_content,
                                              self.test_remote_filepath),
            do_xcom_push=True,
            dag=self.dag
        )
        self.assertIsNotNone(create_file_task)
        ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
        ti1.run()

        # get remote file to local
        get_test_task = SFTPOperator(
            task_id="test_sftp",
            ssh_hook=self.hook,
            local_filepath=self.test_local_filepath,
            remote_filepath=self.test_remote_filepath,
            operation=SFTPOperation.GET,
            dag=self.dag
        )
        self.assertIsNotNone(get_test_task)
        ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
        ti2.run()

        # test the received content
        content_received = None
        with open(self.test_local_filepath, 'r') as file:
            content_received = file.read()
        self.assertEqual(content_received.strip(),
                         test_remote_file_content.encode('utf-8').decode('utf-8'))

    @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
    def test_file_transfer_no_intermediate_dir_error_get(self):
        """GET into a missing local directory must raise 'No such file'."""
        test_remote_file_content = \
            "This is remote file content \n which is also multiline " \
            "another line here \n this is last line. EOF"

        # create a test file remotely
        create_file_task = SSHOperator(
            task_id="test_create_file",
            ssh_hook=self.hook,
            command="echo '{0}' > {1}".format(test_remote_file_content,
                                              self.test_remote_filepath),
            do_xcom_push=True,
            dag=self.dag
        )
        self.assertIsNotNone(create_file_task)
        ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
        ti1.run()

        # Try to GET test file from remote
        # This should raise an error with "No such file" as the directory
        # does not exist
        with self.assertRaises(Exception) as error:
            get_test_task = SFTPOperator(
                task_id="test_sftp",
                ssh_hook=self.hook,
                local_filepath=self.test_local_filepath_int_dir,
                remote_filepath=self.test_remote_filepath,
                operation=SFTPOperation.GET,
                dag=self.dag
            )
            self.assertIsNotNone(get_test_task)
            ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
            ti2.run()
        self.assertIn('No such file', str(error.exception))

    @conf_vars({('core', 'enable_xcom_pickling'): 'True'})
    def test_file_transfer_with_intermediate_dir_error_get(self):
        """GET creates the missing local directory when requested."""
        test_remote_file_content = \
            "This is remote file content \n which is also multiline " \
            "another line here \n this is last line. EOF"

        # create a test file remotely
        create_file_task = SSHOperator(
            task_id="test_create_file",
            ssh_hook=self.hook,
            command="echo '{0}' > {1}".format(test_remote_file_content,
                                              self.test_remote_filepath),
            do_xcom_push=True,
            dag=self.dag
        )
        self.assertIsNotNone(create_file_task)
        ti1 = TaskInstance(task=create_file_task, execution_date=timezone.utcnow())
        ti1.run()

        # get remote file to local
        get_test_task = SFTPOperator(
            task_id="test_sftp",
            ssh_hook=self.hook,
            local_filepath=self.test_local_filepath_int_dir,
            remote_filepath=self.test_remote_filepath,
            operation=SFTPOperation.GET,
            create_intermediate_dirs=True,
            dag=self.dag
        )
        self.assertIsNotNone(get_test_task)
        ti2 = TaskInstance(task=get_test_task, execution_date=timezone.utcnow())
        ti2.run()

        # test the received content
        content_received = None
        with open(self.test_local_filepath_int_dir, 'r') as file:
            content_received = file.read()
        self.assertEqual(content_received.strip(), test_remote_file_content)

    @mock.patch.dict('os.environ', {'AIRFLOW_CONN_' + TEST_CONN_ID.upper(): "ssh://test_id@localhost"})
    def test_arg_checking(self):
        """ssh_hook/ssh_conn_id precedence: hook wins, conn_id is the fallback."""
        # Exception should be raised if neither ssh_hook nor ssh_conn_id is provided
        with self.assertRaisesRegex(AirflowException,
                                    "Cannot operate without ssh_hook or ssh_conn_id."):
            task_0 = SFTPOperator(
                task_id="test_sftp",
                local_filepath=self.test_local_filepath,
                remote_filepath=self.test_remote_filepath,
                operation=SFTPOperation.PUT,
                dag=self.dag
            )
            task_0.execute(None)

        # if ssh_hook is invalid/not provided, use ssh_conn_id to create SSHHook
        task_1 = SFTPOperator(
            task_id="test_sftp",
            ssh_hook="string_rather_than_SSHHook",  # invalid ssh_hook
            ssh_conn_id=TEST_CONN_ID,
            local_filepath=self.test_local_filepath,
            remote_filepath=self.test_remote_filepath,
            operation=SFTPOperation.PUT,
            dag=self.dag
        )
        try:
            task_1.execute(None)
        except Exception:  # pylint: disable=broad-except
            # Only the hook wiring is under test; the transfer itself may fail.
            pass
        self.assertEqual(task_1.ssh_hook.ssh_conn_id, TEST_CONN_ID)

        task_2 = SFTPOperator(
            task_id="test_sftp",
            ssh_conn_id=TEST_CONN_ID,  # no ssh_hook provided
            local_filepath=self.test_local_filepath,
            remote_filepath=self.test_remote_filepath,
            operation=SFTPOperation.PUT,
            dag=self.dag
        )
        try:
            task_2.execute(None)
        except Exception:  # pylint: disable=broad-except
            pass
        self.assertEqual(task_2.ssh_hook.ssh_conn_id, TEST_CONN_ID)

        # if both valid ssh_hook and ssh_conn_id are provided, ignore ssh_conn_id
        task_3 = SFTPOperator(
            task_id="test_sftp",
            ssh_hook=self.hook,
            ssh_conn_id=TEST_CONN_ID,
            local_filepath=self.test_local_filepath,
            remote_filepath=self.test_remote_filepath,
            operation=SFTPOperation.PUT,
            dag=self.dag
        )
        try:
            task_3.execute(None)
        except Exception:  # pylint: disable=broad-except
            pass
        self.assertEqual(task_3.ssh_hook.ssh_conn_id, self.hook.ssh_conn_id)

    def delete_local_resource(self):
        # Best-effort cleanup of local artifacts created by the tests.
        if os.path.exists(self.test_local_filepath):
            os.remove(self.test_local_filepath)
        if os.path.exists(self.test_local_filepath_int_dir):
            os.remove(self.test_local_filepath_int_dir)
        if os.path.exists(self.test_local_dir):
            os.rmdir(self.test_local_dir)

    def delete_remote_resource(self):
        # NOTE(review): the os.path.exists checks test the *local* filesystem,
        # which only matches the remote when SSH targets localhost — confirm.
        if os.path.exists(self.test_remote_filepath):
            # check the remote file content
            remove_file_task = SSHOperator(
                task_id="test_check_file",
                ssh_hook=self.hook,
                command="rm {0}".format(self.test_remote_filepath),
                do_xcom_push=True,
                dag=self.dag
            )
            self.assertIsNotNone(remove_file_task)
            ti3 = TaskInstance(task=remove_file_task, execution_date=timezone.utcnow())
            ti3.run()
        if os.path.exists(self.test_remote_filepath_int_dir):
            os.remove(self.test_remote_filepath_int_dir)
        if os.path.exists(self.test_remote_dir):
            os.rmdir(self.test_remote_dir)

    def tearDown(self):
        # Clean both sides regardless of which direction the test exercised.
        self.delete_local_resource()
        self.delete_remote_resource()
# Allow running this test module directly (outside a test runner).
if __name__ == '__main__':
    unittest.main()
|
@ -0,0 +1,16 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
|
@ -0,0 +1,77 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from paramiko import SFTP_FAILURE, SFTP_NO_SUCH_FILE
|
||||||
|
|
||||||
|
from airflow.providers.sftp.sensors.sftp_sensor import SFTPSensor
|
||||||
|
|
||||||
|
|
||||||
|
class TestSFTPSensor(unittest.TestCase):
    """Unit tests for SFTPSensor.poke with the SFTPHook fully mocked."""

    @patch('airflow.providers.sftp.sensors.sftp_sensor.SFTPHook')
    def test_file_present(self, sftp_hook_mock):
        """poke returns True when the hook reports a modification time."""
        sftp_hook_mock.return_value.get_mod_time.return_value = '19700101000000'
        sftp_sensor = SFTPSensor(
            task_id='unit_test',
            path='/path/to/file/1970-01-01.txt')
        context = {
            'ds': '1970-01-01'
        }
        output = sftp_sensor.poke(context)
        sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
            '/path/to/file/1970-01-01.txt')
        self.assertTrue(output)

    @patch('airflow.providers.sftp.sensors.sftp_sensor.SFTPHook')
    def test_file_absent(self, sftp_hook_mock):
        """poke returns False when the hook raises SFTP_NO_SUCH_FILE."""
        sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(
            SFTP_NO_SUCH_FILE, 'File missing')
        sftp_sensor = SFTPSensor(
            task_id='unit_test',
            path='/path/to/file/1970-01-01.txt')
        context = {
            'ds': '1970-01-01'
        }
        output = sftp_sensor.poke(context)
        sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
            '/path/to/file/1970-01-01.txt')
        self.assertFalse(output)

    @patch('airflow.providers.sftp.sensors.sftp_sensor.SFTPHook')
    def test_sftp_failure(self, sftp_hook_mock):
        """Errors other than a missing file propagate out of poke."""
        sftp_hook_mock.return_value.get_mod_time.side_effect = OSError(
            SFTP_FAILURE, 'SFTP failure')
        sftp_sensor = SFTPSensor(
            task_id='unit_test',
            path='/path/to/file/1970-01-01.txt')
        context = {
            'ds': '1970-01-01'
        }
        with self.assertRaises(OSError):
            sftp_sensor.poke(context)
        sftp_hook_mock.return_value.get_mod_time.assert_called_once_with(
            '/path/to/file/1970-01-01.txt')

    def test_hook_not_created_during_init(self):
        """The hook must be built lazily in poke, not in the constructor."""
        sftp_sensor = SFTPSensor(
            task_id='unit_test',
            path='/path/to/file/1970-01-01.txt')
        self.assertIsNone(sftp_sensor.hook)
Загрузка…
Ссылка в новой задаче